From 5f3ae51bd781db258ab916a66c06042223b82fb5 Mon Sep 17 00:00:00 2001
From: Andy Kwok
Date: Wed, 13 Nov 2024 22:38:11 -0800
Subject: [PATCH] New trendline ppl command (WMA) (#872)

* WMA implementation

Signed-off-by: Andy Kwok

* Update test cases

Signed-off-by: Andy Kwok

* Update tests

Signed-off-by: Andy Kwok

* Refactor code

Signed-off-by: Andy Kwok

* Address comments

Signed-off-by: Andy Kwok

* Update doc

Signed-off-by: Andy Kwok

* Update example

Signed-off-by: Andy Kwok

* Update readme

Signed-off-by: Andy Kwok

* Update scalafmt

Signed-off-by: Andy Kwok

* Update grammar rule

Signed-off-by: Andy Kwok

* Address review comments

Signed-off-by: Andy Kwok

* Address review comments

Signed-off-by: Andy Kwok

---------

Signed-off-by: Andy Kwok
---
 DEVELOPER_GUIDE.md                             |  12 +
 docs/ppl-lang/PPL-Example-Commands.md          |   1 +
 docs/ppl-lang/ppl-trendline-command.md         |  64 ++++-
 .../ppl/FlintSparkPPLTrendlineITSuite.scala    | 268 +++++++++++++++++-
 .../src/main/antlr4/OpenSearchPPLLexer.g4      |   1 +
 .../opensearch/sql/ast/tree/Trendline.java     |   2 +-
 .../function/BuiltinFunctionName.java          |   1 -
 .../sql/ppl/utils/TrendlineCatalystUtils.java  | 193 +++++++++++--
 ...nTrendlineCommandTranslatorTestSuite.scala  | 148 +++++++++-
 9 files changed, 655 insertions(+), 35 deletions(-)

diff --git a/DEVELOPER_GUIDE.md b/DEVELOPER_GUIDE.md
index bb8f697ec..834a2a201 100644
--- a/DEVELOPER_GUIDE.md
+++ b/DEVELOPER_GUIDE.md
@@ -11,6 +11,11 @@ To execute the unit tests, run the following command:
 ```
 sbt test
 ```
+To run a specific unit test in SBT, use the `testOnly` command with the full path of the test class:
+```
+sbt "; project pplSparkIntegration; test:testOnly org.opensearch.flint.spark.ppl.PPLLogicalPlanTrendlineCommandTranslatorTestSuite"
+```
+
 ## Integration Test
 The integration test is defined in the `integration` directory of the project. The integration tests will automatically trigger unit tests and will only run if all unit tests pass. If you want to run the integration test for the project, you can do so by running the following command:
@@ -23,6 +28,13 @@ If you get integration test failures with error message "Previous attempts to fi
 3. Run `sudo ln -s $HOME/.docker/desktop/docker.sock /var/run/docker.sock` or `sudo ln -s $HOME/.docker/run/docker.sock /var/run/docker.sock`
 4. If you use Docker Desktop, as an alternative of `3`, check mark the "Allow the default Docker socket to be used (requires password)" in advanced settings of Docker Desktop.
+Running only a selected set of integration test suites is possible with the following command:
+```
+sbt "; project integtest; it:testOnly org.opensearch.flint.spark.ppl.FlintSparkPPLTrendlineITSuite"
+```
+This command runs only the specified test suite within the integtest submodule.
+
+
 ### AWS Integration Test
 The `aws-integration` folder contains tests for cloud server providers. For instance, to test against an AWS OpenSearch domain, configure the following settings. The client will use the default credential provider to access the AWS OpenSearch domain.
 ```
diff --git a/docs/ppl-lang/PPL-Example-Commands.md b/docs/ppl-lang/PPL-Example-Commands.md
index 851531b5b..7766c3b50 100644
--- a/docs/ppl-lang/PPL-Example-Commands.md
+++ b/docs/ppl-lang/PPL-Example-Commands.md
@@ -65,6 +65,7 @@ _- **Limitation: new field added by eval command with a function cannot be dropp
 - `source = table | where cidrmatch(ip, '192.169.1.0/24')`
 - `source = table | where cidrmatch(ipv6, '2003:db8::/32')`
 - `source = table | trendline sma(2, temperature) as temp_trend`
+- `source = table | trendline sort timestamp wma(2, temperature) as temp_trend`
 
 #### **IP related queries**
 [See additional command details](functions/ppl-ip.md)
diff --git a/docs/ppl-lang/ppl-trendline-command.md b/docs/ppl-lang/ppl-trendline-command.md
index 393a9dd59..b466e2e8f 100644
--- a/docs/ppl-lang/ppl-trendline-command.md
+++ b/docs/ppl-lang/ppl-trendline-command.md
@@ -3,8 +3,7 @@
 **Description**
 Using ``trendline`` command to calculate moving averages of fields.
-
-### Syntax
+### Syntax - SMA (Simple Moving Average)
 `TRENDLINE [sort <[+|-] sort-field>] SMA(number-of-datapoints, field) [AS alias] [SMA(number-of-datapoints, field) [AS alias]]...`
 
 * [+|-]: optional. The plus [+] stands for ascending order and NULL/MISSING first and a minus [-] stands for descending order and NULL/MISSING last. **Default:** ascending order and NULL/MISSING first.
@@ -13,8 +12,6 @@ Using ``trendline`` command to calculate moving averages of fields.
 * field: mandatory. the name of the field the moving average should be calculated for.
 * alias: optional. the name of the resulting column containing the moving average.
 
-And the moment only the Simple Moving Average (SMA) type is supported.
-
 It is calculated like
 
     f[i]: The value of field 'f' in the i-th data-point
     n: The number of data-points in the moving window (period)
     t: The current time-index
 
     SMA(t) = (1/n) * Σ(f[i]), where i = t-n+1 to t
 
@@ -23,7 +20,7 @@
-### Example 1: Calculate simple moving average for a timeseries of temperatures
+#### Example 1: Calculate simple moving average for a timeseries of temperatures
 
 The example calculates the simple moving average over temperatures using two datapoints.
 
 PPL query:
 
     |         15|      258|2023-04-06 17:07:...|      14.5|
     +-----------+---------+--------------------+----------+
 
@@ -41,7 +38,7 @@
-### Example 2: Calculate simple moving averages for a timeseries of temperatures with sorting
+#### Example 2: Calculate simple moving averages for a timeseries of temperatures with sorting
 
 The example calculates two simple moving averages over temperatures using two and three datapoints sorted descending by device-id.
 
 PPL query:
 
     |         12|     1492|2023-04-06 17:07:...|        12.5|              13.0|
     |         12|     1492|2023-04-06 17:07:...|        12.0|12.333333333333334|
     +-----------+---------+--------------------+------------+------------------+
@@ -58,3 +55,58 @@
+
+
+### Syntax - WMA (Weighted Moving Average)
+`TRENDLINE sort <[+|-] sort-field> WMA(number-of-datapoints, field) [AS alias] [WMA(number-of-datapoints, field) [AS alias]]...`
+
+* [+|-]: optional. The plus [+] stands for ascending order and NULL/MISSING first and a minus [-] stands for descending order and NULL/MISSING last. **Default:** ascending order and NULL/MISSING first.
+* sort-field: mandatory. this field specifies the ordering of data points when calculating the nth_value aggregation.
+* number-of-datapoints: mandatory. number of datapoints to calculate the moving average (must be greater than zero).
+* field: mandatory. the name of the field the moving average should be calculated for.
+* alias: optional. the name of the resulting column containing the moving average.
+
+It is calculated like
+
+    f[i]: The value of field 'f' in the i-th data point
+    n: The number of data points in the moving window (period)
+    t: The current time index
+    w[i]: The weight assigned to the i-th data point, typically increasing for more recent points
+
+    WMA(t) = ( Σ from i=t−n+1 to t of (w[i] * f[i]) ) / ( Σ from i=t−n+1 to t of w[i] )
+
+For example, with n=2 the weights are (1, 2), so WMA(t) = (f[t-1] + 2*f[t]) / 3.
+
+#### Example 1: Calculate weighted moving average for a timeseries of temperatures
+
+The example calculates the weighted moving average over temperatures using two datapoints.
+
+PPL query:
+
+    os> source=t | trendline sort timestamp wma(2, temperature) as temp_trend;
+    fetched rows / total rows = 5/5
+    +-----------+---------+--------------------+----------+
+    |temperature|device-id|           timestamp|temp_trend|
+    +-----------+---------+--------------------+----------+
+    |         12|     1492|2023-04-06 17:07:...|      NULL|
+    |         12|     1492|2023-04-06 17:07:...|      12.0|
+    |         13|      256|2023-04-06 17:07:...|      12.6|
+    |         14|      257|2023-04-06 17:07:...|      13.6|
+    |         15|      258|2023-04-06 17:07:...|      14.6|
+    +-----------+---------+--------------------+----------+
+
+#### Example 2: Calculate weighted moving averages for a timeseries of temperatures with sorting
+
+The example calculates two weighted moving averages over temperatures using two and three datapoints sorted descending by device-id.
+
+PPL query:
+
+    os> source=t | trendline sort - device-id wma(2, temperature) as temp_trend_2 wma(3, temperature) as temp_trend_3;
+    fetched rows / total rows = 5/5
+    +-----------+---------+--------------------+------------+------------------+
+    |temperature|device-id|           timestamp|temp_trend_2|      temp_trend_3|
+    +-----------+---------+--------------------+------------+------------------+
+    |         15|      258|2023-04-06 17:07:...|        NULL|              NULL|
+    |         14|      257|2023-04-06 17:07:...|        14.3|              NULL|
+    |         13|      256|2023-04-06 17:07:...|        13.3|              13.6|
+    |         12|     1492|2023-04-06 17:07:...|        12.3|              12.6|
+    |         12|     1492|2023-04-06 17:07:...|        12.0|             12.16|
+    +-----------+---------+--------------------+------------+------------------+
diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTrendlineITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTrendlineITSuite.scala
index bc4463537..9a8379288 100644
--- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTrendlineITSuite.scala
+++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTrendlineITSuite.scala
@@ -5,9 +5,14 @@
 package org.opensearch.flint.spark.ppl
 
+import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq
+import org.opensearch.sql.ppl.utils.SortUtils
+import org.scalatest.matchers.should.Matchers.{a, convertToAnyShouldWrapper}
+
 import org.apache.spark.sql.{QueryTest, Row}
 import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
-import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, CaseWhen, CurrentRow, Descending, LessThan, Literal, RowFrame, SortOrder, SpecifiedWindowFrame, WindowExpression, WindowSpecDefinition}
+import org.apache.spark.sql.catalyst.expressions.{Add, Alias, Ascending, CaseWhen, CurrentRow, Descending, Divide, Expression, LessThan, Literal, Multiply, RowFrame, SortOrder, SpecifiedWindowFrame, WindowExpression, WindowSpecDefinition}
+import org.apache.spark.sql.catalyst.parser.ParseException
 import org.apache.spark.sql.catalyst.plans.logical._
 import 
org.apache.spark.sql.streaming.StreamTest @@ -244,4 +249,265 @@ class FlintSparkPPLTrendlineITSuite implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) assert(results.sorted.sameElements(expectedResults.sorted)) } + + test("test trendline wma command with sort field and without alias") { + val frame = sql(s""" + | source = $testTable | trendline sort + age wma(3, age) + | """.stripMargin) + + // Compare the headers + assert( + frame.columns.sameElements( + Array("name", "age", "state", "country", "year", "month", "age_trendline"))) + // Retrieve the results + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array( + Row("Jane", 20, "Quebec", "Canada", 2023, 4, null), + Row("John", 25, "Ontario", "Canada", 2023, 4, null), + Row("Hello", 30, "New York", "USA", 2023, 4, 26.666666666666668), + Row("Jake", 70, "California", "USA", 2023, 4, 49.166666666666664)) + + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Compare the logical plans + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val dividend = Add( + Add( + getNthValueAggregation("age", "age", 1, -2), + getNthValueAggregation("age", "age", 2, -2)), + getNthValueAggregation("age", "age", 3, -2)) + val wmaExpression = Divide(dividend, Literal(6)) + val trendlineProjectList = Seq(UnresolvedStar(None), Alias(wmaExpression, "age_trendline")()) + val unresolvedRelation = UnresolvedRelation(testTable.split("\\.").toSeq) + val sortedTable = Sort( + Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), + global = true, + unresolvedRelation) + val expectedPlan = + Project(Seq(UnresolvedStar(None)), Project(trendlineProjectList, sortedTable)) + + /** + * Expected logical plan: 'Project [*] +- 'Project [*, ((( ('nth_value('age, 1) + * windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, + * currentrow$())) * 1) + ('nth_value('age, 2) windowspecdefinition('age ASC NULLS FIRST, + * specifiedwindowframe(RowFrame, -2, currentrow$())) * 2)) + ('nth_value('age, 3) + * windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, + * currentrow$())) * 3)) / 6) AS age_trendline#185] +- 'Sort ['age ASC NULLS FIRST], true +- + * 'UnresolvedRelation [spark_catalog, default, flint_ppl_test], [], false + */ + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test trendline wma command with sort field and with alias") { + val frame = sql(s""" + | source = $testTable | trendline sort + age wma(3, age) as trendline_alias + | """.stripMargin) + + // Compare the headers + assert( + frame.columns.sameElements( + Array("name", "age", "state", "country", "year", "month", "trendline_alias"))) + // Retrieve the results + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array( + Row("Jane", 20, "Quebec", "Canada", 2023, 4, null), + Row("John", 25, "Ontario", "Canada", 2023, 4, null), + Row("Hello", 30, "New York", "USA", 2023, 4, 26.666666666666668), + Row("Jake", 70, "California", "USA", 2023, 4, 49.166666666666664)) + + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Compare the logical plans + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val dividend = Add( + Add( + getNthValueAggregation("age", "age", 1, -2), + 
getNthValueAggregation("age", "age", 2, -2)), + getNthValueAggregation("age", "age", 3, -2)) + val wmaExpression = Divide(dividend, Literal(6)) + val trendlineProjectList = + Seq(UnresolvedStar(None), Alias(wmaExpression, "trendline_alias")()) + val unresolvedRelation = UnresolvedRelation(testTable.split("\\.").toSeq) + val sortedTable = Sort( + Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), + global = true, + unresolvedRelation) + val expectedPlan = + Project(Seq(UnresolvedStar(None)), Project(trendlineProjectList, sortedTable)) + + /** + * 'Project [*] +- 'Project [*, ((( ('nth_value('age, 1) windowspecdefinition('age ASC NULLS + * FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 1) + ('nth_value('age, 2) + * windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, + * currentrow$())) * 2)) + ('nth_value('age, 3) windowspecdefinition('age ASC NULLS FIRST, + * specifiedwindowframe(RowFrame, -2, currentrow$())) * 3)) / 6) AS trendline_alias#185] +- + * 'Sort ['age ASC NULLS FIRST], true +- 'UnresolvedRelation [spark_catalog, default, + * flint_ppl_test], [], false + */ + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test multiple trendline wma commands") { + val frame = sql(s""" + | source = $testTable | trendline sort + age wma(2, age) as two_points_wma wma(3, age) as three_points_wma + | """.stripMargin) + + // Compare the headers + assert( + frame.columns.sameElements( + Array( + "name", + "age", + "state", + "country", + "year", + "month", + "two_points_wma", + "three_points_wma"))) + // Retrieve the results + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array( + Row("Jane", 20, "Quebec", "Canada", 2023, 4, null, null), + Row("John", 25, "Ontario", "Canada", 2023, 4, 23.333333333333332, null), + Row("Hello", 30, "New York", "USA", 2023, 4, 28.333333333333332, 26.666666666666668), + Row("Jake", 70, "California", "USA", 2023, 4, 56.666666666666664, 49.166666666666664)) + + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Compare the logical plans + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + val dividendTwo = Add( + getNthValueAggregation("age", "age", 1, -1), + getNthValueAggregation("age", "age", 2, -1)) + val twoPointsExpression = Divide(dividendTwo, Literal(3)) + + val dividend = Add( + Add( + getNthValueAggregation("age", "age", 1, -2), + getNthValueAggregation("age", "age", 2, -2)), + getNthValueAggregation("age", "age", 3, -2)) + val threePointsExpression = Divide(dividend, Literal(6)) + + val trendlineProjectList = Seq( + UnresolvedStar(None), + Alias(twoPointsExpression, "two_points_wma")(), + Alias(threePointsExpression, "three_points_wma")()) + val unresolvedRelation = UnresolvedRelation(testTable.split("\\.").toSeq) + val sortedTable = Sort( + Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), + global = true, + unresolvedRelation) + val expectedPlan = + Project(Seq(UnresolvedStar(None)), Project(trendlineProjectList, sortedTable)) + + /** + * 'Project [*] +- 'Project [*, (( ('nth_value('age, 1) windowspecdefinition('age ASC NULLS + * FIRST, specifiedwindowframe(RowFrame, -1, currentrow$())) * 1) + ('nth_value('age, 2) + * windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -1, + * currentrow$())) * 2)) / 3) AS two_points_wma#247, + * + * ((( ('nth_value('age, 1) windowspecdefinition('age ASC 
NULLS FIRST, + * specifiedwindowframe(RowFrame, -2, currentrow$())) * 1) + ('nth_value('age, 2) + * windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, + * currentrow$())) * 2)) + ('nth_value('age, 3) windowspecdefinition('age ASC NULLS FIRST, + * specifiedwindowframe(RowFrame, -2, currentrow$())) * 3)) / 6) AS three_points_wma#248] +- + * 'Sort ['age ASC NULLS FIRST], true +- 'UnresolvedRelation [spark_catalog, default, + * flint_ppl_test], [], false + */ + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test trendline wma command on evaluated column") { + val frame = sql(s""" + | source = $testTable | eval doubled_age = age * 2 | trendline sort + age wma(2, doubled_age) as doubled_age_wma | fields name, doubled_age, doubled_age_wma + | """.stripMargin) + + // Compare the headers + assert(frame.columns.sameElements(Array("name", "doubled_age", "doubled_age_wma"))) + // Retrieve the results + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array( + Row("Jane", 40, null), + Row("John", 50, 46.666666666666664), + Row("Hello", 60, 56.666666666666664), + Row("Jake", 140, 113.33333333333333)) + + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Compare the logical plans + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val dividend = Add( + getNthValueAggregation("doubled_age", "age", 1, -1), + getNthValueAggregation("doubled_age", "age", 2, -1)) + val wmaExpression = Divide(dividend, Literal(3)) + val trendlineProjectList = + Seq(UnresolvedStar(None), Alias(wmaExpression, "doubled_age_wma")()) + val unresolvedRelation = UnresolvedRelation(testTable.split("\\.").toSeq) + val doubledAged = Alias( + UnresolvedFunction( + seq("*"), + seq(UnresolvedAttribute("age"), Literal(2)), + isDistinct = false), + "doubled_age")() + val doubleAgeProject = Project(seq(UnresolvedStar(None), doubledAged), unresolvedRelation) + val sortedTable = + Sort(Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), global = true, doubleAgeProject) + val expectedPlan = Project( + Seq( + UnresolvedAttribute("name"), + UnresolvedAttribute("doubled_age"), + UnresolvedAttribute("doubled_age_wma")), + Project(trendlineProjectList, sortedTable)) + + /** + * 'Project ['name, 'doubled_age, 'doubled_age_wma] +- 'Project [*, (( + * ('nth_value('doubled_age, 1) windowspecdefinition('age ASC NULLS FIRST, + * specifiedwindowframe(RowFrame, -1, currentrow$())) * 1) + ('nth_value('doubled_age, 2) + * windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -1, + * currentrow$())) * 2)) / 3) AS doubled_age_wma#288] +- 'Sort ['age ASC NULLS FIRST], true +- + * 'Project [*, '`*`('age, 2) AS doubled_age#287] +- 'UnresolvedRelation [spark_catalog, + * default, flint_ppl_test], [], false + */ + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + + } + + test("test invalid wma command with negative dataPoint value") { + val exception = intercept[ParseException](sql(s""" + | source = $testTable | trendline sort + age wma(-3, age) + | """.stripMargin)) + assert(exception.getMessage contains "[PARSE_SYNTAX_ERROR] Syntax error") + } + + private def getNthValueAggregation( + dataField: String, + sortField: String, + lookBackPos: Int, + lookBackRange: Int): Expression = { + Multiply( + WindowExpression( + UnresolvedFunction( + "nth_value", + Seq(UnresolvedAttribute(dataField), 
Literal(lookBackPos)),
+          isDistinct = false),
+        WindowSpecDefinition(
+          Seq(),
+          seq(SortUtils.sortOrder(UnresolvedAttribute(sortField), true)),
+          SpecifiedWindowFrame(RowFrame, Literal(lookBackRange), CurrentRow))),
+      Literal(lookBackPos))
+  }
 }
diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4
index 02818c1fb..f3c6acda9 100644
--- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4
+++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4
@@ -96,6 +96,7 @@ NULLS: 'NULLS';
 
 //TRENDLINE KEYWORDS
 SMA: 'SMA';
+WMA: 'WMA';
 
 // ARGUMENT KEYWORDS
 KEEPEMPTY: 'KEEPEMPTY';
diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Trendline.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Trendline.java
index 9fa1ae81d..d08e89e3b 100644
--- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Trendline.java
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Trendline.java
@@ -62,6 +62,6 @@ public TrendlineComputation(Integer numberOfDataPoints, UnresolvedExpression dat
   }
 
   public enum TrendlineType {
-    SMA
+    SMA, WMA
   }
 }
diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java
index f039bf47f..86970cefb 100644
--- a/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java
@@ -292,7 +292,6 @@ public enum BuiltinFunctionName {
   MULTIMATCHQUERY(FunctionName.of("multimatchquery")),
   WILDCARDQUERY(FunctionName.of("wildcardquery")),
   WILDCARD_QUERY(FunctionName.of("wildcard_query")),
-
   COALESCE(FunctionName.of("coalesce"));
 
   private FunctionName name;
diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/TrendlineCatalystUtils.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/TrendlineCatalystUtils.java
index 67603ccc7..647f4542e 100644
--- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/TrendlineCatalystUtils.java
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/TrendlineCatalystUtils.java
@@ -5,31 +5,40 @@
 package org.opensearch.sql.ppl.utils;
 
+import org.apache.spark.sql.catalyst.analysis.UnresolvedFunction;
 import org.apache.spark.sql.catalyst.expressions.*;
-import org.opensearch.sql.ast.expression.AggregateFunction;
-import org.opensearch.sql.ast.expression.DataType;
+import org.opensearch.sql.ast.expression.*;
 import org.opensearch.sql.ast.expression.Literal;
 import org.opensearch.sql.ast.tree.Trendline;
 import org.opensearch.sql.expression.function.BuiltinFunctionName;
 import org.opensearch.sql.ppl.CatalystExpressionVisitor;
 import org.opensearch.sql.ppl.CatalystPlanContext;
+import scala.collection.mutable.Seq;
 import scala.Option;
 import scala.Tuple2;
 
+import java.util.ArrayList;
+import java.util.Collections;
 import java.util.List;
+import java.util.Optional;
 import java.util.stream.Collectors;
 
 import static org.opensearch.sql.ppl.utils.DataTypeTransformer.seq;
+import static scala.Option.empty;
+import static scala.collection.JavaConverters.asScalaBufferConverter;
 
 public interface TrendlineCatalystUtils {
-    static List<NamedExpression> visitTrendlineComputations(CatalystExpressionVisitor expressionVisitor, List<Trendline.TrendlineComputation> computations, CatalystPlanContext context) {
+
+    static List<NamedExpression> visitTrendlineComputations(CatalystExpressionVisitor expressionVisitor, List<Trendline.TrendlineComputation> computations, Optional<Field> sortField, CatalystPlanContext context) {
         return computations.stream()
-                .map(computation -> visitTrendlineComputation(expressionVisitor, computation, context))
+                .map(computation -> visitTrendlineComputation(expressionVisitor, computation, sortField, context))
                 .collect(Collectors.toList());
     }
 
-    static NamedExpression visitTrendlineComputation(CatalystExpressionVisitor expressionVisitor, Trendline.TrendlineComputation node, CatalystPlanContext context) {
+
+    static NamedExpression visitTrendlineComputation(CatalystExpressionVisitor expressionVisitor, Trendline.TrendlineComputation node, Optional<Field> sortField, CatalystPlanContext context) {
+
         //window lower boundary
         expressionVisitor.visitLiteral(new Literal(Math.negateExact(node.getNumberOfDataPoints() - 1), DataType.INTEGER), context);
         Expression windowLowerBoundary = context.popNamedParseExpressions().get();
@@ -40,26 +49,28 @@ static NamedExpression visitTrendlineComputation(CatalystExpressionVisitor expre
             seq(),
             new SpecifiedWindowFrame(RowFrame$.MODULE$, windowLowerBoundary, CurrentRow$.MODULE$));
 
-        if (node.getComputationType() == Trendline.TrendlineType.SMA) {
-            //calculate avg value of the data field
-            expressionVisitor.visitAggregateFunction(new AggregateFunction(BuiltinFunctionName.AVG.name(), node.getDataField()), context);
-            Expression avgFunction = context.popNamedParseExpressions().get();
-
-            //sma window
-            WindowExpression sma = new WindowExpression(
-                    avgFunction,
-                    windowDefinition);
-
-            CaseWhen smaOrNull = trendlineOrNullWhenThereAreTooFewDataPoints(expressionVisitor, sma, node, context);
-
-            return org.apache.spark.sql.catalyst.expressions.Alias$.MODULE$.apply(smaOrNull,
-                node.getAlias(),
-                NamedExpression.newExprId(),
-                seq(new java.util.ArrayList<Expression>()),
-                Option.empty(),
-                seq(new java.util.ArrayList<Expression>()));
-        } else {
-            throw new IllegalArgumentException(node.getComputationType()+" is not supported");
+        switch (node.getComputationType()) {
+            case SMA:
+                //calculate avg value of the data field
+                expressionVisitor.visitAggregateFunction(new AggregateFunction(BuiltinFunctionName.AVG.name(), node.getDataField()), context);
+                Expression avgFunction = context.popNamedParseExpressions().get();
+
+                //sma window
+                WindowExpression sma = new WindowExpression(
+                        avgFunction,
+                        windowDefinition);
+
+                CaseWhen smaOrNull = trendlineOrNullWhenThereAreTooFewDataPoints(expressionVisitor, sma, node, context);
+
+                return getAlias(node.getAlias(), smaOrNull);
+            case WMA:
+                if (sortField.isPresent()) {
+                    return getWMAComputationExpression(expressionVisitor, node, sortField.get(), context);
+                } else {
+                    throw new IllegalArgumentException(node.getComputationType()+" requires a sort field for computation");
+                }
+            default:
+                throw new IllegalArgumentException(node.getComputationType()+" is not supported");
         }
     }
@@ -84,4 +95,136 @@ private static CaseWhen trendlineOrNullWhenThereAreTooFewDataPoints(CatalystExpr
         );
         return new CaseWhen(seq(nullWhenNumberOfDataPointsLessThenRequired), Option.apply(trendlineWindow));
     }
+
+    /**
+     * Produces a Spark logical plan for the given trendline command arguments; below is a sample logical plan
+     * with configuration [dataField=salary, sortField=age, dataPoints=3]
+     * -- +- 'Project [
+     * --    (((('nth_value('salary, 1) windowspecdefinition(Field(field=age, fieldArgs=[]) ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 1) +
+     * --    ('nth_value('salary, 2) windowspecdefinition(Field(field=age, fieldArgs=[]) ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 2)) +
+     * --    ('nth_value('salary, 3) windowspecdefinition(Field(field=age, fieldArgs=[]) ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 3)) / 6)
+     * --    AS WMA#702]
+     * .
+     * And the corresponding SQL query:
+     * .
+     * SELECT name, salary,
+     *     ( nth_value(salary, 1) OVER (ORDER BY age ROWS BETWEEN 2 PRECEDING AND CURRENT ROW) *1 +
+     *       nth_value(salary, 2) OVER (ORDER BY age ROWS BETWEEN 2 PRECEDING AND CURRENT ROW) *2 +
+     *       nth_value(salary, 3) OVER (ORDER BY age ROWS BETWEEN 2 PRECEDING AND CURRENT ROW) *3 )/6 AS WMA
+     * FROM employees
+     * ORDER BY age;
+     *
+     * @param visitor Visitor instance to process any UnresolvedExpression.
+     * @param node Trendline command's arguments.
+     * @param sortField Field used for window aggregation.
+     * @param context Context instance used to retrieve Expressions in resolved form.
+     * @return A NamedExpression instance which will calculate the WMA with the provided arguments.
+     */
+    private static NamedExpression getWMAComputationExpression(CatalystExpressionVisitor visitor,
+                                                               Trendline.TrendlineComputation node,
+                                                               Field sortField,
+                                                               CatalystPlanContext context) {
+        int dataPoints = node.getNumberOfDataPoints();
+        //window lower boundary
+        Expression windowLowerBoundary = parseIntToExpression(visitor, context,
+                Math.negateExact(dataPoints - 1));
+        //window definition
+        visitor.analyze(sortField, context);
+        Expression sortDefinition = context.popNamedParseExpressions().get();
+        WindowSpecDefinition windowDefinition = getWmaCommonWindowDefinition(
+                sortDefinition,
+                SortUtils.isSortedAscending(sortField),
+                windowLowerBoundary);
+        // Divisor
+        Expression divisor = parseIntToExpression(visitor, context,
+                (dataPoints * (dataPoints + 1) / 2));
+        // Aggregation
+        Expression wmaExpression = getNthValueAggregations(visitor, node, context, windowDefinition, dataPoints)
+                .stream()
+                .reduce(Add::new)
+                .orElse(null);
+
+        return getAlias(node.getAlias(), new Divide(wmaExpression, divisor));
+    }
+
+    /**
+     * Helper method to produce an Alias Expression with the provided value and name.
+     * @param name The name for the Alias.
+     * @param expression The expression which will be evaluated.
+     * @return An Alias instance with the logical plan representation of `expression AS name`.
+     */
+    private static NamedExpression getAlias(String name, Expression expression) {
+        return org.apache.spark.sql.catalyst.expressions.Alias$.MODULE$.apply(expression,
+                name,
+                NamedExpression.newExprId(),
+                seq(Collections.emptyList()),
+                Option.empty(),
+                seq(Collections.emptyList()));
+    }
+
+    /**
+     * Helper method to retrieve an Int expression instance for logical plan composition purposes.
+     * @param expressionVisitor Visitor instance to process the incoming object.
+     * @param context Context instance to retrieve the Expression instance.
+     * @param i Target value for the expression.
+     * @return An expression object which contains the integer value i.
+     */
+    static Expression parseIntToExpression(CatalystExpressionVisitor expressionVisitor, CatalystPlanContext context, int i) {
+        expressionVisitor.visitLiteral(new Literal(i,
+                DataType.INTEGER), context);
+        return context.popNamedParseExpressions().get();
+    }
+
+
+    /**
+     * Helper method to retrieve a WindowSpecDefinition with the provided sorting condition.
+     * `windowspecdefinition('sortField ascending NULLS FIRST, specifiedwindowframe(RowFrame, windowLowerBoundary, currentrow$())`
+     *
+     * @param sortField The field being used for the sorting operation.
+     * @param ascending Whether the sort order is ascending.
+     * @param windowLowerBoundary The Integer expression which specifies the lookbehind / lookahead window boundary.
+     * @return A WindowSpecDefinition instance which will be used to compose the WMA calculation.
+     */
+    static WindowSpecDefinition getWmaCommonWindowDefinition(Expression sortField, boolean ascending, Expression windowLowerBoundary) {
+        return new WindowSpecDefinition(
+                seq(),
+                seq(SortUtils.sortOrder(sortField, ascending)),
+                new SpecifiedWindowFrame(RowFrame$.MODULE$, windowLowerBoundary, CurrentRow$.MODULE$));
+    }
+
+    /**
+     * Produces a list of Expressions that return the appropriate lookbehind / lookahead values for the WMA calculation; a sample logical plan is listed below.
+     * (((('nth_value('salary, 1) windowspecdefinition(Field(field=age, fieldArgs=[]) ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 1) +
+     *
+     * @param visitor Visitor instance to resolve Expressions.
+     * @param node Trendline command instruction.
+     * @param context Context instance to retrieve the resolved expression.
+     * @param windowDefinition The windowDefinition for the individual datapoint lookbehind / lookahead.
+     * @param dataPoints Number of data points for the WMA calculation; this always equals the number of Expressions generated.
+     * @return A List containing the expressions for the WMA's individual datapoint calculations.
+     */
+    private static List<Expression> getNthValueAggregations(CatalystExpressionVisitor visitor,
+                                                            Trendline.TrendlineComputation node,
+                                                            CatalystPlanContext context,
+                                                            WindowSpecDefinition windowDefinition,
+                                                            int dataPoints) {
+        List<Expression> expressions = new ArrayList<>();
+        for (int i = 1; i <= dataPoints; i++) {
+            // Get the offset parameter
+            Expression offSetExpression = parseIntToExpression(visitor, context, i);
+            // Get the dataField in Expression
+            visitor.analyze(node.getDataField(), context);
+            Expression dataField = context.popNamedParseExpressions().get();
+            // nth_value Expression
+            UnresolvedFunction nthValueExp = new UnresolvedFunction(
+                    asScalaBufferConverter(List.of("nth_value")).asScala().seq(),
+                    asScalaBufferConverter(List.of(dataField, offSetExpression)).asScala().seq(),
+                    false, empty(), false);
+
+            expressions.add(new Multiply(
+                    new WindowExpression(nthValueExp, windowDefinition), offSetExpression));
+        }
+        return expressions;
+    }
+
 }
diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTrendlineCommandTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTrendlineCommandTranslatorTestSuite.scala
index d22750ee0..ec1775631 100644
--- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTrendlineCommandTranslatorTestSuite.scala
+++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTrendlineCommandTranslatorTestSuite.scala
@@ -6,12 +6,15 @@
 package org.opensearch.flint.spark.ppl
 
 import org.opensearch.flint.spark.ppl.PlaneUtils.plan
+import org.opensearch.sql.common.antlr.SyntaxCheckException
 import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor}
+import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq
+import org.opensearch.sql.ppl.utils.SortUtils
 import org.scalatest.matchers.should.Matchers
import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} -import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, CaseWhen, CurrentRow, Descending, LessThan, Literal, RowFrame, SortOrder, SpecifiedWindowFrame, WindowExpression, WindowSpecDefinition} +import org.apache.spark.sql.catalyst.expressions.{Add, Alias, Ascending, CaseWhen, CurrentRow, Descending, Divide, Expression, LessThan, Literal, Multiply, RowFrame, SortOrder, SpecifiedWindowFrame, WindowExpression, WindowSpecDefinition} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{Project, Sort} @@ -132,4 +135,147 @@ class PPLLogicalPlanTrendlineCommandTranslatorTestSuite Project(trendlineProjectList, sort)) comparePlans(logPlan, expectedPlan, checkAnalysis = false) } + + test("WMA - with sort") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, "source=relation | trendline sort age wma(3, age)"), + context) + + val dividend = Add( + Add( + getNthValueAggregation("age", "age", 1, -2), + getNthValueAggregation("age", "age", 2, -2)), + getNthValueAggregation("age", "age", 3, -2)) + val wmaExpression = Divide(dividend, Literal(6)) + val trendlineProjectList = Seq(UnresolvedStar(None), Alias(wmaExpression, "age_trendline")()) + val sortedTable = Sort( + Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), + global = true, + UnresolvedRelation(Seq("relation"))) + val expectedPlan = + Project(Seq(UnresolvedStar(None)), Project(trendlineProjectList, sortedTable)) + + /** + * Expected logical plan: 'Project [*] !+- 'Project [*, ((( ('nth_value('age, 1) + * windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, + * currentrow$())) * 1) + ('nth_value('age, 2) windowspecdefinition('age ASC NULLS FIRST, + * specifiedwindowframe(RowFrame, -2, currentrow$())) * 2)) + ('nth_value('age, 3) + * windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, + * currentrow$())) * 3)) / 6) AS age_trendline#0] ! +- 'Sort ['age ASC NULLS FIRST], true ! +- + * 'UnresolvedRelation [relation], [], false + */ + comparePlans(logPlan, expectedPlan, checkAnalysis = false) + } + + test("WMA - with sort and alias") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, "source=relation | trendline sort age wma(3, age) as TEST_CUSTOM_COLUMN"), + context) + + val dividend = Add( + Add( + getNthValueAggregation("age", "age", 1, -2), + getNthValueAggregation("age", "age", 2, -2)), + getNthValueAggregation("age", "age", 3, -2)) + val wmaExpression = Divide(dividend, Literal(6)) + val trendlineProjectList = + Seq(UnresolvedStar(None), Alias(wmaExpression, "TEST_CUSTOM_COLUMN")()) + val sortedTable = Sort( + Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), + global = true, + UnresolvedRelation(Seq("relation"))) + + /** + * Expected logical plan: 'Project [*] !+- 'Project [*, ((( ('nth_value('age, 1) + * windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, + * currentrow$())) * 1) + ('nth_value('age, 2) windowspecdefinition('age ASC NULLS FIRST, + * specifiedwindowframe(RowFrame, -2, currentrow$())) * 2)) + ('nth_value('age, 3) + * windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, + * currentrow$())) * 3)) / 6) AS TEST_CUSTOM_COLUMN#0] ! +- 'Sort ['age ASC NULLS FIRST], true + * ! 
+- 'UnresolvedRelation [relation], [], false + */ + val expectedPlan = + Project(Seq(UnresolvedStar(None)), Project(trendlineProjectList, sortedTable)) + comparePlans(logPlan, expectedPlan, checkAnalysis = false) + + } + + test("WMA - multiple trendline commands") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source=relation | trendline sort age wma(2, age) as two_points_wma wma(3, age) as three_points_wma"), + context) + + val dividendTwo = Add( + getNthValueAggregation("age", "age", 1, -1), + getNthValueAggregation("age", "age", 2, -1)) + val twoPointsExpression = Divide(dividendTwo, Literal(3)) + + val dividend = Add( + Add( + getNthValueAggregation("age", "age", 1, -2), + getNthValueAggregation("age", "age", 2, -2)), + getNthValueAggregation("age", "age", 3, -2)) + val threePointsExpression = Divide(dividend, Literal(6)) + val trendlineProjectList = Seq( + UnresolvedStar(None), + Alias(twoPointsExpression, "two_points_wma")(), + Alias(threePointsExpression, "three_points_wma")()) + val sortedTable = Sort( + Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), + global = true, + UnresolvedRelation(Seq("relation"))) + + /** + * Expected logical plan: 'Project [*] +- 'Project [*, (( ('nth_value('age, 1) + * windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -1, + * currentrow$())) * 1) + ('nth_value('age, 2) windowspecdefinition('age ASC NULLS FIRST, + * specifiedwindowframe(RowFrame, -1, currentrow$())) * 2)) / 3) AS two_points_wma#0, + * + * ((( ('nth_value('age, 1) windowspecdefinition('age ASC NULLS FIRST, + * specifiedwindowframe(RowFrame, -2, currentrow$())) * 1) + ('nth_value('age, 2) + * windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, + * currentrow$())) * 2)) + ('nth_value('age, 3) windowspecdefinition('age ASC NULLS FIRST, + * specifiedwindowframe(RowFrame, -2, currentrow$())) * 3)) / 6) AS three_points_wma#1] +- + * 'Sort ['age ASC NULLS FIRST], true +- 'UnresolvedRelation [relation], [], false + */ + val expectedPlan = + Project(Seq(UnresolvedStar(None)), Project(trendlineProjectList, sortedTable)) + comparePlans(logPlan, expectedPlan, checkAnalysis = false) + + } + + test("WMA - with negative dataPoint value") { + val context = new CatalystPlanContext + val exception = intercept[SyntaxCheckException]( + planTransformer + .visit(plan(pplParser, "source=relation | trendline sort age wma(-3, age)"), context)) + assert(exception.getMessage startsWith "Failed to parse query due to offending symbol [-]") + } + + private def getNthValueAggregation( + dataField: String, + sortField: String, + lookBackPos: Int, + lookBackRange: Int): Expression = { + Multiply( + WindowExpression( + UnresolvedFunction( + "nth_value", + Seq(UnresolvedAttribute(dataField), Literal(lookBackPos)), + isDistinct = false), + WindowSpecDefinition( + Seq(), + seq(SortUtils.sortOrder(UnresolvedAttribute(sortField), true)), + SpecifiedWindowFrame(RowFrame, Literal(lookBackRange), CurrentRow))), + Literal(lookBackPos)) + } + }
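
As a quick sanity check of the WMA arithmetic exercised by the tests above, here is a minimal, dependency-free Scala sketch (not part of the patch; `WmaCheck` and `wma` are hypothetical names) that mirrors the `nth_value`-based plan built by `getWMAComputationExpression`: weights `1..n`, the `n * (n + 1) / 2` divisor, and no value for the first `n - 1` rows (NULL in the Spark output).

```scala
// Reference weighted moving average over an already-sorted sequence:
// WMA(t) = (sum of w[i] * f[i] for the n points ending at t) / (1 + 2 + ... + n),
// with weights 1..n so the most recent point weighs the most.
object WmaCheck {
  def wma(values: Seq[Double], n: Int): Seq[Option[Double]] = {
    val divisor = n * (n + 1) / 2 // same divisor the Catalyst plan uses
    values.indices.map { t =>
      if (t < n - 1) None // window not yet filled: NULL in the Spark output
      else {
        val window = values.slice(t - n + 1, t + 1) // n points ending at index t
        Some(window.zipWithIndex.map { case (v, i) => v * (i + 1) }.sum / divisor)
      }
    }
  }

  def main(args: Array[String]): Unit = {
    // Ages from FlintSparkPPLTrendlineITSuite, sorted ascending: 20, 25, 30, 70
    println(wma(Seq(20, 25, 30, 70), 3))
    // Vector(None, None, Some(26.666666666666668), Some(49.166666666666664))
  }
}
```

The two defined values match the expected results asserted for `trendline sort + age wma(3, age)` in the integration test.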