From bdb4848960c3b1cc54a37c77c47af3727df4c265 Mon Sep 17 00:00:00 2001 From: Hendrik Saly Date: Fri, 1 Nov 2024 20:22:33 +0100 Subject: [PATCH] New trendline ppl command (SMA only) (#833) * WIP trendline command Signed-off-by: Kacper Trochimiak * wip Signed-off-by: Kacper Trochimiak * trendline supports sorting Signed-off-by: Kacper Trochimiak * run scalafmtAll Signed-off-by: Kacper Trochimiak * return null when there are too few data points Signed-off-by: Kacper Trochimiak * sbt scalafmtAll Signed-off-by: Kacper Trochimiak * Remove WMA references Signed-off-by: Hendrik Saly * trendline - sortByField as Optional Signed-off-by: Kacper Trochimiak * introduce TrendlineStrategy Signed-off-by: Kacper Trochimiak * keywordsCanBeId -> replace SMA with trendlineType Signed-off-by: Kacper Trochimiak * handle trendline alias as qualifiedName instead of fieldExpression Signed-off-by: Kacper Trochimiak * Add docs Signed-off-by: Hendrik Saly * Make alias optional Signed-off-by: Hendrik Saly * Adapt tests for optional alias Signed-off-by: Hendrik Saly * Adden logical plan unittests Signed-off-by: Hendrik Saly * Add missing license headers Signed-off-by: Hendrik Saly * Fix docs Signed-off-by: Hendrik Saly * numberOfDataPoints must be 1 or greater Signed-off-by: Hendrik Saly * Rename TrendlineStrategy to TrendlineCatalystUtils Signed-off-by: Hendrik Saly * Validate TrendlineType early and pass around enum type Signed-off-by: Hendrik Saly * Add trendline chaining test Signed-off-by: Hendrik Saly * Fix compile errors Signed-off-by: Hendrik Saly * Fix imports Signed-off-by: Hendrik Saly * Fix imports Signed-off-by: Hendrik Saly --------- Signed-off-by: Kacper Trochimiak Signed-off-by: Hendrik Saly Co-authored-by: Kacper Trochimiak --- docs/ppl-lang/PPL-Example-Commands.md | 1 + docs/ppl-lang/README.md | 3 +- docs/ppl-lang/ppl-trendline-command.md | 60 +++++ .../ppl/FlintSparkPPLTrendlineITSuite.scala | 247 ++++++++++++++++++ .../src/main/antlr4/OpenSearchPPLLexer.g4 | 4 + .../src/main/antlr4/OpenSearchPPLParser.g4 | 14 + .../sql/ast/AbstractNodeVisitor.java | 4 + .../opensearch/sql/ast/tree/Trendline.java | 67 +++++ .../sql/ppl/CatalystExpressionVisitor.java | 9 + .../sql/ppl/CatalystQueryPlanVisitor.java | 73 +++--- .../opensearch/sql/ppl/parser/AstBuilder.java | 24 ++ .../sql/ppl/parser/AstExpressionBuilder.java | 2 + .../sql/ppl/utils/DataTypeTransformer.java | 6 +- .../opensearch/sql/ppl/utils/SortUtils.java | 6 +- .../sql/ppl/utils/TrendlineCatalystUtils.java | 87 ++++++ ...nTrendlineCommandTranslatorTestSuite.scala | 135 ++++++++++ 16 files changed, 698 insertions(+), 44 deletions(-) create mode 100644 docs/ppl-lang/ppl-trendline-command.md create mode 100644 integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTrendlineITSuite.scala create mode 100644 ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Trendline.java create mode 100644 ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/TrendlineCatalystUtils.java create mode 100644 ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTrendlineCommandTranslatorTestSuite.scala diff --git a/docs/ppl-lang/PPL-Example-Commands.md b/docs/ppl-lang/PPL-Example-Commands.md index 561b5b27b..409b128c9 100644 --- a/docs/ppl-lang/PPL-Example-Commands.md +++ b/docs/ppl-lang/PPL-Example-Commands.md @@ -60,6 +60,7 @@ _- **Limitation: new field added by eval command with a function cannot be dropp - `source = table | where b not between '2024-09-10' and '2025-09-10'` - Note: This returns b 
>= '2024-09-10' and b <= '2025-09-10' - `source = table | where cidrmatch(ip, '192.169.1.0/24')` - `source = table | where cidrmatch(ipv6, '2003:db8::/32')` +- `source = table | trendline sma(2, temperature) as temp_trend` ```sql source = table | eval status_category = diff --git a/docs/ppl-lang/README.md b/docs/ppl-lang/README.md index 43e9579aa..6ba49b031 100644 --- a/docs/ppl-lang/README.md +++ b/docs/ppl-lang/README.md @@ -69,7 +69,8 @@ For additional examples see the next [documentation](PPL-Example-Commands.md). - [`subquery commands`](ppl-subquery-command.md) - [`correlation commands`](ppl-correlation-command.md) - + + - [`trendline commands`](ppl-trendline-command.md) * **Functions** diff --git a/docs/ppl-lang/ppl-trendline-command.md b/docs/ppl-lang/ppl-trendline-command.md new file mode 100644 index 000000000..393a9dd59 --- /dev/null +++ b/docs/ppl-lang/ppl-trendline-command.md @@ -0,0 +1,60 @@ +## PPL trendline Command + +**Description** +The ``trendline`` command calculates moving averages of fields. + + +### Syntax +`TRENDLINE [sort <[+|-] sort-field>] SMA(number-of-datapoints, field) [AS alias] [SMA(number-of-datapoints, field) [AS alias]]...` + +* [+|-]: optional. The plus [+] stands for ascending order with NULL/MISSING first, and the minus [-] stands for descending order with NULL/MISSING last. **Default:** ascending order and NULL/MISSING first. +* sort-field: mandatory when sorting is used. The field used to sort. +* number-of-datapoints: mandatory. Number of datapoints used to calculate the moving average (must be greater than zero). +* field: mandatory. The name of the field the moving average should be calculated for. +* alias: optional. The name of the resulting column containing the moving average. + +At the moment only the Simple Moving Average (SMA) type is supported. + +It is calculated as follows: + + f[i]: The value of field 'f' in the i-th data-point + n: The number of data-points in the moving window (period) + t: The current time index + + SMA(t) = (1/n) * Σ(f[i]), where i = t-n+1 to t + +### Example 1: Calculate simple moving average for a timeseries of temperatures + +The example calculates the simple moving average over temperatures using two datapoints. + +PPL query: + + os> source=t | trendline sma(2, temperature) as temp_trend; + fetched rows / total rows = 5/5 + +-----------+---------+--------------------+----------+ + |temperature|device-id| timestamp|temp_trend| + +-----------+---------+--------------------+----------+ + | 12| 1492|2023-04-06 17:07:...| NULL| + | 12| 1492|2023-04-06 17:07:...| 12.0| + | 13| 256|2023-04-06 17:07:...| 12.5| + | 14| 257|2023-04-06 17:07:...| 13.5| + | 15| 258|2023-04-06 17:07:...| 14.5| + +-----------+---------+--------------------+----------+ + +### Example 2: Calculate simple moving averages for a timeseries of temperatures with sorting + +The example calculates two simple moving averages over temperatures using two and three datapoints, sorted in descending order by device-id.
+ +PPL query: + + os> source=t | trendline sort - device-id sma(2, temperature) as temp_trend_2 sma(3, temperature) as temp_trend_3; + fetched rows / total rows = 5/5 + +-----------+---------+--------------------+------------+------------------+ + |temperature|device-id| timestamp|temp_trend_2| temp_trend_3| + +-----------+---------+--------------------+------------+------------------+ + | 15| 258|2023-04-06 17:07:...| NULL| NULL| + | 14| 257|2023-04-06 17:07:...| 14.5| NULL| + | 13| 256|2023-04-06 17:07:...| 13.5| 14.0| + | 12| 1492|2023-04-06 17:07:...| 12.5| 13.0| + | 12| 1492|2023-04-06 17:07:...| 12.0|12.333333333333334| + +-----------+---------+--------------------+------------+------------------+ diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTrendlineITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTrendlineITSuite.scala new file mode 100644 index 000000000..bc4463537 --- /dev/null +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTrendlineITSuite.scala @@ -0,0 +1,247 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, CaseWhen, CurrentRow, Descending, LessThan, Literal, RowFrame, SortOrder, SpecifiedWindowFrame, WindowExpression, WindowSpecDefinition} +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.streaming.StreamTest + +class FlintSparkPPLTrendlineITSuite + extends QueryTest + with LogicalPlanTestUtils + with FlintPPLSuite + with StreamTest { + + /** Test table and index name */ + private val testTable = "spark_catalog.default.flint_ppl_test" + + override def beforeAll(): Unit = { + super.beforeAll() + + // Create test table + createPartitionedStateCountryTable(testTable) + } + + protected override def afterEach(): Unit = { + super.afterEach() + // Stop all streaming jobs if any + spark.streams.active.foreach { job => + job.stop() + job.awaitTermination() + } + } + + test("test trendline sma command without fields command and without alias") { + val frame = sql(s""" + | source = $testTable | sort - age | trendline sma(2, age) + | """.stripMargin) + + assert( + frame.columns.sameElements( + Array("name", "age", "state", "country", "year", "month", "age_trendline"))) + // Retrieve the results + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array( + Row("Jake", 70, "California", "USA", 2023, 4, null), + Row("Hello", 30, "New York", "USA", 2023, 4, 50.0), + Row("John", 25, "Ontario", "Canada", 2023, 4, 27.5), + Row("Jane", 20, "Quebec", "Canada", 2023, 4, 22.5)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val ageField = UnresolvedAttribute("age") + val sort = Sort(Seq(SortOrder(ageField, Descending)), global = true, table) + val countWindow = new WindowExpression( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + WindowSpecDefinition(Seq(), Seq(), 
SpecifiedWindowFrame(RowFrame, Literal(-1), CurrentRow))) + val smaWindow = WindowExpression( + UnresolvedFunction("AVG", Seq(ageField), isDistinct = false), + WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-1), CurrentRow))) + val caseWhen = CaseWhen(Seq((LessThan(countWindow, Literal(2)), Literal(null))), smaWindow) + val trendlineProjectList = Seq(UnresolvedStar(None), Alias(caseWhen, "age_trendline")()) + val expectedPlan = Project(Seq(UnresolvedStar(None)), Project(trendlineProjectList, sort)) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test trendline sma command with fields command") { + val frame = sql(s""" + | source = $testTable | trendline sort - age sma(3, age) as age_sma | fields name, age, age_sma + | """.stripMargin) + + assert(frame.columns.sameElements(Array("name", "age", "age_sma"))) + // Retrieve the results + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array( + Row("Jake", 70, null), + Row("Hello", 30, null), + Row("John", 25, 41.666666666666664), + Row("Jane", 20, 25)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val nameField = UnresolvedAttribute("name") + val ageField = UnresolvedAttribute("age") + val ageSmaField = UnresolvedAttribute("age_sma") + val sort = Sort(Seq(SortOrder(ageField, Descending)), global = true, table) + val countWindow = new WindowExpression( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-2), CurrentRow))) + val smaWindow = WindowExpression( + UnresolvedFunction("AVG", Seq(ageField), isDistinct = false), + WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-2), CurrentRow))) + val caseWhen = CaseWhen(Seq((LessThan(countWindow, Literal(3)), Literal(null))), smaWindow) + val trendlineProjectList = Seq(UnresolvedStar(None), Alias(caseWhen, "age_sma")()) + val expectedPlan = + Project(Seq(nameField, ageField, ageSmaField), Project(trendlineProjectList, sort)) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test multiple trendline sma commands") { + val frame = sql(s""" + | source = $testTable | trendline sort + age sma(2, age) as two_points_sma sma(3, age) as three_points_sma | fields name, age, two_points_sma, three_points_sma + | """.stripMargin) + + assert(frame.columns.sameElements(Array("name", "age", "two_points_sma", "three_points_sma"))) + // Retrieve the results + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array( + Row("Jane", 20, null, null), + Row("John", 25, 22.5, null), + Row("Hello", 30, 27.5, 25.0), + Row("Jake", 70, 50.0, 41.666666666666664)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val nameField = UnresolvedAttribute("name") + val ageField = UnresolvedAttribute("age") + val ageTwoPointsSmaField = UnresolvedAttribute("two_points_sma") 
+ val ageThreePointsSmaField = UnresolvedAttribute("three_points_sma") + val sort = Sort(Seq(SortOrder(ageField, Ascending)), global = true, table) + val twoPointsCountWindow = new WindowExpression( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-1), CurrentRow))) + val twoPointsSmaWindow = WindowExpression( + UnresolvedFunction("AVG", Seq(ageField), isDistinct = false), + WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-1), CurrentRow))) + val threePointsCountWindow = new WindowExpression( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-2), CurrentRow))) + val threePointsSmaWindow = WindowExpression( + UnresolvedFunction("AVG", Seq(ageField), isDistinct = false), + WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-2), CurrentRow))) + val twoPointsCaseWhen = CaseWhen( + Seq((LessThan(twoPointsCountWindow, Literal(2)), Literal(null))), + twoPointsSmaWindow) + val threePointsCaseWhen = CaseWhen( + Seq((LessThan(threePointsCountWindow, Literal(3)), Literal(null))), + threePointsSmaWindow) + val trendlineProjectList = Seq( + UnresolvedStar(None), + Alias(twoPointsCaseWhen, "two_points_sma")(), + Alias(threePointsCaseWhen, "three_points_sma")()) + val expectedPlan = Project( + Seq(nameField, ageField, ageTwoPointsSmaField, ageThreePointsSmaField), + Project(trendlineProjectList, sort)) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test trendline sma command on evaluated column") { + val frame = sql(s""" + | source = $testTable | eval doubled_age = age * 2 | trendline sort + age sma(2, doubled_age) as doubled_age_sma | fields name, doubled_age, doubled_age_sma + | """.stripMargin) + + assert(frame.columns.sameElements(Array("name", "doubled_age", "doubled_age_sma"))) + // Retrieve the results + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array( + Row("Jane", 40, null), + Row("John", 50, 45.0), + Row("Hello", 60, 55.0), + Row("Jake", 140, 100.0)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val nameField = UnresolvedAttribute("name") + val ageField = UnresolvedAttribute("age") + val doubledAgeField = UnresolvedAttribute("doubled_age") + val doubledAgeSmaField = UnresolvedAttribute("doubled_age_sma") + val evalProject = Project( + Seq( + UnresolvedStar(None), + Alias( + UnresolvedFunction("*", Seq(ageField, Literal(2)), isDistinct = false), + "doubled_age")()), + table) + val sort = Sort(Seq(SortOrder(ageField, Ascending)), global = true, evalProject) + val countWindow = new WindowExpression( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-1), CurrentRow))) + val doubleAgeSmaWindow = WindowExpression( + UnresolvedFunction("AVG", Seq(doubledAgeField), isDistinct = false), + WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-1), CurrentRow))) + val caseWhen = + CaseWhen(Seq((LessThan(countWindow, Literal(2)), Literal(null))), doubleAgeSmaWindow) + val 
trendlineProjectList = + Seq(UnresolvedStar(None), Alias(caseWhen, "doubled_age_sma")()) + val expectedPlan = Project( + Seq(nameField, doubledAgeField, doubledAgeSmaField), + Project(trendlineProjectList, sort)) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test trendline sma command chaining") { + val frame = sql(s""" + | source = $testTable | eval age_1 = age, age_2 = age | trendline sort - age_1 sma(3, age_1) | trendline sort + age_2 sma(3, age_2) + | """.stripMargin) + + assert( + frame.columns.sameElements( + Array( + "name", + "age", + "state", + "country", + "year", + "month", + "age_1", + "age_2", + "age_1_trendline", + "age_2_trendline"))) + // Retrieve the results + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array( + Row("Hello", 30, "New York", "USA", 2023, 4, 30, 30, null, 25.0), + Row("Jake", 70, "California", "USA", 2023, 4, 70, 70, null, 41.666666666666664), + Row("Jane", 20, "Quebec", "Canada", 2023, 4, 20, 20, 25.0, null), + Row("John", 25, "Ontario", "Canada", 2023, 4, 25, 25, 41.666666666666664, null)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + } +} diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 index 58d10a560..991a4dffe 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 @@ -38,6 +38,7 @@ AD: 'AD'; ML: 'ML'; FILLNULL: 'FILLNULL'; FLATTEN: 'FLATTEN'; +TRENDLINE: 'TRENDLINE'; //Native JOIN KEYWORDS JOIN: 'JOIN'; @@ -90,6 +91,9 @@ FIELDSUMMARY: 'FIELDSUMMARY'; INCLUDEFIELDS: 'INCLUDEFIELDS'; NULLS: 'NULLS'; +//TRENDLINE KEYWORDS +SMA: 'SMA'; + // ARGUMENT KEYWORDS KEEPEMPTY: 'KEEPEMPTY'; CONSECUTIVE: 'CONSECUTIVE'; diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 index 8bb93567b..cd6fe5dc1 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 @@ -54,6 +54,7 @@ commands | fillnullCommand | fieldsummaryCommand | flattenCommand + | trendlineCommand ; commandName @@ -84,6 +85,7 @@ commandName | FILLNULL | FIELDSUMMARY | FLATTEN + | TRENDLINE ; searchCommand @@ -252,6 +254,17 @@ flattenCommand : FLATTEN fieldExpression ; +trendlineCommand + : TRENDLINE (SORT sortField)? trendlineClause (trendlineClause)* + ; + +trendlineClause + : trendlineType LT_PRTHS numberOfDataPoints = integerLiteral COMMA field = fieldExpression RT_PRTHS (AS alias = qualifiedName)? 
+ ; + +trendlineType + : SMA + ; kmeansCommand : KMEANS (kmeansParameter)* @@ -1131,4 +1144,5 @@ keywordsCanBeId | ANTI | BETWEEN | CIDRMATCH + | trendlineType ; diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java index 00db5b675..525a0954c 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java @@ -111,6 +111,10 @@ public T visitLookup(Lookup node, C context) { return visitChildren(node, context); } + public T visitTrendline(Trendline node, C context) { + return visitChildren(node, context); + } + public T visitCorrelation(Correlation node, C context) { return visitChildren(node, context); } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Trendline.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Trendline.java new file mode 100644 index 000000000..9fa1ae81d --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Trendline.java @@ -0,0 +1,67 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.tree; + +import com.google.common.collect.ImmutableList; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import lombok.ToString; +import org.opensearch.sql.ast.AbstractNodeVisitor; +import org.opensearch.sql.ast.Node; +import org.opensearch.sql.ast.expression.Field; +import org.opensearch.sql.ast.expression.UnresolvedExpression; + +import java.util.List; +import java.util.Optional; + +@ToString +@Getter +@RequiredArgsConstructor +@EqualsAndHashCode(callSuper = false) +public class Trendline extends UnresolvedPlan { + + private UnresolvedPlan child; + private final Optional sortByField; + private final List computations; + + @Override + public UnresolvedPlan attach(UnresolvedPlan child) { + this.child = child; + return this; + } + + @Override + public List getChild() { + return ImmutableList.of(child); + } + + @Override + public T accept(AbstractNodeVisitor visitor, C context) { + return visitor.visitTrendline(this, context); + } + + @Getter + public static class TrendlineComputation { + + private final Integer numberOfDataPoints; + private final UnresolvedExpression dataField; + private final String alias; + private final TrendlineType computationType; + + public TrendlineComputation(Integer numberOfDataPoints, UnresolvedExpression dataField, String alias, Trendline.TrendlineType computationType) { + this.numberOfDataPoints = numberOfDataPoints; + this.dataField = dataField; + this.alias = alias; + this.computationType = computationType; + } + + } + + public enum TrendlineType { + SMA + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystExpressionVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystExpressionVisitor.java index 571905f8a..69a89b83a 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystExpressionVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystExpressionVisitor.java @@ -9,18 +9,24 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation; import org.apache.spark.sql.catalyst.analysis.UnresolvedStar$; import org.apache.spark.sql.catalyst.expressions.CaseWhen; +import 
org.apache.spark.sql.catalyst.expressions.CurrentRow$; import org.apache.spark.sql.catalyst.expressions.Exists$; import org.apache.spark.sql.catalyst.expressions.Expression; import org.apache.spark.sql.catalyst.expressions.GreaterThanOrEqual; import org.apache.spark.sql.catalyst.expressions.In$; import org.apache.spark.sql.catalyst.expressions.InSubquery$; +import org.apache.spark.sql.catalyst.expressions.LessThan; import org.apache.spark.sql.catalyst.expressions.LessThanOrEqual; import org.apache.spark.sql.catalyst.expressions.ListQuery$; import org.apache.spark.sql.catalyst.expressions.MakeInterval$; import org.apache.spark.sql.catalyst.expressions.NamedExpression; import org.apache.spark.sql.catalyst.expressions.Predicate; +import org.apache.spark.sql.catalyst.expressions.RowFrame$; import org.apache.spark.sql.catalyst.expressions.ScalaUDF; import org.apache.spark.sql.catalyst.expressions.ScalarSubquery$; +import org.apache.spark.sql.catalyst.expressions.SpecifiedWindowFrame; +import org.apache.spark.sql.catalyst.expressions.WindowExpression; +import org.apache.spark.sql.catalyst.expressions.WindowSpecDefinition; import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan; import org.apache.spark.sql.types.DataTypes; import org.opensearch.sql.ast.AbstractNodeVisitor; @@ -32,6 +38,7 @@ import org.opensearch.sql.ast.expression.BinaryExpression; import org.opensearch.sql.ast.expression.Case; import org.opensearch.sql.ast.expression.Compare; +import org.opensearch.sql.ast.expression.DataType; import org.opensearch.sql.ast.expression.FieldsMapping; import org.opensearch.sql.ast.expression.Function; import org.opensearch.sql.ast.expression.In; @@ -54,7 +61,9 @@ import org.opensearch.sql.ast.tree.FillNull; import org.opensearch.sql.ast.tree.Kmeans; import org.opensearch.sql.ast.tree.RareTopN; +import org.opensearch.sql.ast.tree.Trendline; import org.opensearch.sql.ast.tree.UnresolvedPlan; +import org.opensearch.sql.expression.function.BuiltinFunctionName; import org.opensearch.sql.expression.function.SerializableUdf; import org.opensearch.sql.ppl.utils.AggregatorTransformer; import org.opensearch.sql.ppl.utils.BuiltinFunctionTransformer; diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java index 3ad1b95cb..669459fba 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java @@ -6,29 +6,30 @@ package org.opensearch.sql.ppl; import org.apache.spark.sql.catalyst.TableIdentifier; -import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute$; import org.apache.spark.sql.catalyst.analysis.UnresolvedFunction; import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation; import org.apache.spark.sql.catalyst.analysis.UnresolvedStar$; import org.apache.spark.sql.catalyst.expressions.Ascending$; -import org.apache.spark.sql.catalyst.expressions.CaseWhen; import org.apache.spark.sql.catalyst.expressions.Descending$; -import org.apache.spark.sql.catalyst.expressions.Exists$; import org.apache.spark.sql.catalyst.expressions.Expression; import org.apache.spark.sql.catalyst.expressions.GeneratorOuter; import org.apache.spark.sql.catalyst.expressions.In$; import org.apache.spark.sql.catalyst.expressions.GreaterThanOrEqual; import org.apache.spark.sql.catalyst.expressions.InSubquery$; +import 
org.apache.spark.sql.catalyst.expressions.LessThan; import org.apache.spark.sql.catalyst.expressions.LessThanOrEqual; import org.apache.spark.sql.catalyst.expressions.ListQuery$; import org.apache.spark.sql.catalyst.expressions.MakeInterval$; import org.apache.spark.sql.catalyst.expressions.NamedExpression; -import org.apache.spark.sql.catalyst.expressions.Predicate; -import org.apache.spark.sql.catalyst.expressions.ScalarSubquery$; -import org.apache.spark.sql.catalyst.expressions.ScalaUDF; import org.apache.spark.sql.catalyst.expressions.SortDirection; import org.apache.spark.sql.catalyst.expressions.SortOrder; -import org.apache.spark.sql.catalyst.plans.logical.*; +import org.apache.spark.sql.catalyst.plans.logical.Aggregate; +import org.apache.spark.sql.catalyst.plans.logical.DataFrameDropColumns$; +import org.apache.spark.sql.catalyst.plans.logical.DescribeRelation$; +import org.apache.spark.sql.catalyst.plans.logical.Generate; +import org.apache.spark.sql.catalyst.plans.logical.Limit; +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan; +import org.apache.spark.sql.catalyst.plans.logical.Project$; import org.apache.spark.sql.execution.ExplainMode; import org.apache.spark.sql.execution.command.DescribeTableCommand; import org.apache.spark.sql.execution.command.ExplainCommand; @@ -36,35 +37,16 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap; import org.opensearch.flint.spark.FlattenGenerator; import org.opensearch.sql.ast.AbstractNodeVisitor; -import org.opensearch.sql.ast.expression.AggregateFunction; import org.opensearch.sql.ast.expression.Alias; -import org.opensearch.sql.ast.expression.AllFields; -import org.opensearch.sql.ast.expression.And; import org.opensearch.sql.ast.expression.Argument; -import org.opensearch.sql.ast.expression.Between; -import org.opensearch.sql.ast.expression.BinaryExpression; -import org.opensearch.sql.ast.expression.Case; -import org.opensearch.sql.ast.expression.Compare; import org.opensearch.sql.ast.expression.Field; -import org.opensearch.sql.ast.expression.FieldsMapping; import org.opensearch.sql.ast.expression.Function; import org.opensearch.sql.ast.expression.In; -import org.opensearch.sql.ast.expression.subquery.ExistsSubquery; -import org.opensearch.sql.ast.expression.subquery.InSubquery; -import org.opensearch.sql.ast.expression.Interval; -import org.opensearch.sql.ast.expression.IsEmpty; import org.opensearch.sql.ast.expression.Let; import org.opensearch.sql.ast.expression.Literal; -import org.opensearch.sql.ast.expression.Not; -import org.opensearch.sql.ast.expression.Or; import org.opensearch.sql.ast.expression.ParseMethod; -import org.opensearch.sql.ast.expression.QualifiedName; -import org.opensearch.sql.ast.expression.subquery.ScalarSubquery; -import org.opensearch.sql.ast.expression.Span; import org.opensearch.sql.ast.expression.UnresolvedExpression; -import org.opensearch.sql.ast.expression.When; import org.opensearch.sql.ast.expression.WindowFunction; -import org.opensearch.sql.ast.expression.Xor; import org.opensearch.sql.ast.statement.Explain; import org.opensearch.sql.ast.statement.Query; import org.opensearch.sql.ast.statement.Statement; @@ -90,20 +72,16 @@ import org.opensearch.sql.ast.tree.Sort; import org.opensearch.sql.ast.tree.SubqueryAlias; import org.opensearch.sql.ast.tree.TopAggregation; -import org.opensearch.sql.ast.tree.UnresolvedPlan; +import org.opensearch.sql.ast.tree.Trendline; import org.opensearch.sql.ast.tree.Window; import org.opensearch.sql.common.antlr.SyntaxCheckException; -import 
org.opensearch.sql.expression.function.SerializableUdf; -import org.opensearch.sql.ppl.utils.AggregatorTransformer; -import org.opensearch.sql.ppl.utils.BuiltinFunctionTransformer; -import org.opensearch.sql.ppl.utils.ComparatorTransformer; import org.opensearch.sql.ppl.utils.FieldSummaryTransformer; import org.opensearch.sql.ppl.utils.ParseTransformer; import org.opensearch.sql.ppl.utils.SortUtils; +import org.opensearch.sql.ppl.utils.TrendlineCatalystUtils; import org.opensearch.sql.ppl.utils.WindowSpecTransformer; import scala.None$; import scala.Option; -import scala.Tuple2; import scala.collection.IterableLike; import scala.collection.Seq; @@ -111,16 +89,11 @@ import java.util.List; import java.util.Objects; import java.util.Optional; -import java.util.Stack; -import java.util.function.BiFunction; import java.util.stream.Collectors; import static java.util.Collections.emptyList; import static java.util.List.of; -import static org.opensearch.sql.expression.function.BuiltinFunctionName.EQUAL; -import static org.opensearch.sql.ppl.CatalystPlanContext.findRelation; import static org.opensearch.sql.ppl.utils.DataTypeTransformer.seq; -import static org.opensearch.sql.ppl.utils.DataTypeTransformer.translate; import static org.opensearch.sql.ppl.utils.DedupeTransformer.retainMultipleDuplicateEvents; import static org.opensearch.sql.ppl.utils.DedupeTransformer.retainMultipleDuplicateEventsAndKeepEmpty; import static org.opensearch.sql.ppl.utils.DedupeTransformer.retainOneDuplicateEvent; @@ -132,8 +105,6 @@ import static org.opensearch.sql.ppl.utils.LookupTransformer.buildOutputProjectList; import static org.opensearch.sql.ppl.utils.LookupTransformer.buildProjectListFromFields; import static org.opensearch.sql.ppl.utils.RelationUtils.getTableIdentifier; -import static org.opensearch.sql.ppl.utils.RelationUtils.resolveField; -import static org.opensearch.sql.ppl.utils.WindowSpecTransformer.window; import static scala.collection.JavaConverters.seqAsJavaList; /** @@ -255,6 +226,30 @@ public LogicalPlan visitLookup(Lookup node, CatalystPlanContext context) { }); } + @Override + public LogicalPlan visitTrendline(Trendline node, CatalystPlanContext context) { + node.getChild().get(0).accept(this, context); + + node.getSortByField() + .ifPresent(sortField -> { + Expression sortFieldExpression = visitExpression(sortField, context); + Seq sortOrder = context + .retainAllNamedParseExpressions(exp -> SortUtils.sortOrder(sortFieldExpression, SortUtils.isSortedAscending(sortField))); + context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Sort(sortOrder, true, p)); + }); + + List trendlineProjectExpressions = new ArrayList<>(); + + if (context.getNamedParseExpressions().isEmpty()) { + // Create an UnresolvedStar for all-fields projection + trendlineProjectExpressions.add(UnresolvedStar$.MODULE$.apply(Option.empty())); + } + + trendlineProjectExpressions.addAll(TrendlineCatalystUtils.visitTrendlineComputations(expressionAnalyzer, node.getComputations(), context)); + + return context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Project(seq(trendlineProjectExpressions), p)); + } + @Override public LogicalPlan visitCorrelation(Correlation node, CatalystPlanContext context) { node.getChild().get(0).accept(this, context); diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java index 09db8b126..4e6b1f131 100644 --- 
a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java @@ -386,6 +386,30 @@ private java.util.Map buildLookupPair(List (Alias) and.getLeft(), and -> (Field) and.getRight(), (x, y) -> y, LinkedHashMap::new)); } + @Override + public UnresolvedPlan visitTrendlineCommand(OpenSearchPPLParser.TrendlineCommandContext ctx) { + List trendlineComputations = ctx.trendlineClause() + .stream() + .map(this::toTrendlineComputation) + .collect(Collectors.toList()); + return Optional.ofNullable(ctx.sortField()) + .map(this::internalVisitExpression) + .map(Field.class::cast) + .map(sort -> new Trendline(Optional.of(sort), trendlineComputations)) + .orElse(new Trendline(Optional.empty(), trendlineComputations)); + } + + private Trendline.TrendlineComputation toTrendlineComputation(OpenSearchPPLParser.TrendlineClauseContext ctx) { + int numberOfDataPoints = Integer.parseInt(ctx.numberOfDataPoints.getText()); + if (numberOfDataPoints < 1) { + throw new SyntaxCheckException("Number of trendline data-points must be greater than or equal to 1"); + } + Field dataField = (Field) expressionBuilder.visitFieldExpression(ctx.field); + String alias = ctx.alias == null?dataField.getField().toString()+"_trendline":ctx.alias.getText(); + String computationType = ctx.trendlineType().getText(); + return new Trendline.TrendlineComputation(numberOfDataPoints, dataField, alias, Trendline.TrendlineType.valueOf(computationType.toUpperCase())); + } + /** Top command. */ @Override public UnresolvedPlan visitTopCommand(OpenSearchPPLParser.TopCommandContext ctx) { diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java index b6dfd0447..5e0f0775d 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java @@ -43,6 +43,8 @@ import org.opensearch.sql.ast.expression.subquery.ExistsSubquery; import org.opensearch.sql.ast.expression.subquery.InSubquery; import org.opensearch.sql.ast.expression.subquery.ScalarSubquery; +import org.opensearch.sql.ast.tree.Trendline; +import org.opensearch.sql.common.antlr.SyntaxCheckException; import org.opensearch.sql.common.utils.StringUtils; import org.opensearch.sql.ppl.utils.ArgumentFactory; diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/DataTypeTransformer.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/DataTypeTransformer.java index 62eef90ed..e4defad52 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/DataTypeTransformer.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/DataTypeTransformer.java @@ -14,16 +14,14 @@ import org.apache.spark.sql.types.FloatType$; import org.apache.spark.sql.types.IntegerType$; import org.apache.spark.sql.types.LongType$; +import org.apache.spark.sql.types.NullType$; import org.apache.spark.sql.types.ShortType$; import org.apache.spark.sql.types.StringType$; import org.apache.spark.unsafe.types.UTF8String; import org.opensearch.sql.ast.expression.SpanUnit; import scala.collection.mutable.Seq; -import java.util.Arrays; import java.util.List; -import java.util.Objects; -import java.util.stream.Collectors; import static org.opensearch.sql.ast.expression.SpanUnit.DAY; 
import static org.opensearch.sql.ast.expression.SpanUnit.HOUR; @@ -67,6 +65,8 @@ static DataType translate(org.opensearch.sql.ast.expression.DataType source) { return ShortType$.MODULE$; case BYTE: return ByteType$.MODULE$; + case UNDEFINED: + return NullType$.MODULE$; default: return StringType$.MODULE$; } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/SortUtils.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/SortUtils.java index 83603b031..803daea8b 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/SortUtils.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/SortUtils.java @@ -38,7 +38,7 @@ static SortOrder getSortDirection(Sort node, NamedExpression expression) { .findAny(); return field.map(value -> sortOrder((Expression) expression, - (Boolean) value.getFieldArgs().get(0).getValue().getValue())) + isSortedAscending(value))) .orElse(null); } @@ -51,4 +51,8 @@ static SortOrder sortOrder(Expression expression, boolean ascending) { seq(new ArrayList()) ); } + + static boolean isSortedAscending(Field field) { + return (Boolean) field.getFieldArgs().get(0).getValue().getValue(); + } } \ No newline at end of file diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/TrendlineCatalystUtils.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/TrendlineCatalystUtils.java new file mode 100644 index 000000000..67603ccc7 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/TrendlineCatalystUtils.java @@ -0,0 +1,87 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ppl.utils; + +import org.apache.spark.sql.catalyst.expressions.*; +import org.opensearch.sql.ast.expression.AggregateFunction; +import org.opensearch.sql.ast.expression.DataType; +import org.opensearch.sql.ast.expression.Literal; +import org.opensearch.sql.ast.tree.Trendline; +import org.opensearch.sql.expression.function.BuiltinFunctionName; +import org.opensearch.sql.ppl.CatalystExpressionVisitor; +import org.opensearch.sql.ppl.CatalystPlanContext; +import scala.Option; +import scala.Tuple2; + +import java.util.List; +import java.util.stream.Collectors; + +import static org.opensearch.sql.ppl.utils.DataTypeTransformer.seq; + +public interface TrendlineCatalystUtils { + + static List visitTrendlineComputations(CatalystExpressionVisitor expressionVisitor, List computations, CatalystPlanContext context) { + return computations.stream() + .map(computation -> visitTrendlineComputation(expressionVisitor, computation, context)) + .collect(Collectors.toList()); + } + + static NamedExpression visitTrendlineComputation(CatalystExpressionVisitor expressionVisitor, Trendline.TrendlineComputation node, CatalystPlanContext context) { + //window lower boundary + expressionVisitor.visitLiteral(new Literal(Math.negateExact(node.getNumberOfDataPoints() - 1), DataType.INTEGER), context); + Expression windowLowerBoundary = context.popNamedParseExpressions().get(); + + //window definition + WindowSpecDefinition windowDefinition = new WindowSpecDefinition( + seq(), + seq(), + new SpecifiedWindowFrame(RowFrame$.MODULE$, windowLowerBoundary, CurrentRow$.MODULE$)); + + if (node.getComputationType() == Trendline.TrendlineType.SMA) { + //calculate avg value of the data field + expressionVisitor.visitAggregateFunction(new AggregateFunction(BuiltinFunctionName.AVG.name(), node.getDataField()), context); + Expression 
avgFunction = context.popNamedParseExpressions().get(); + + //sma window + WindowExpression sma = new WindowExpression( + avgFunction, + windowDefinition); + + CaseWhen smaOrNull = trendlineOrNullWhenThereAreTooFewDataPoints(expressionVisitor, sma, node, context); + + return org.apache.spark.sql.catalyst.expressions.Alias$.MODULE$.apply(smaOrNull, + node.getAlias(), + NamedExpression.newExprId(), + seq(new java.util.ArrayList()), + Option.empty(), + seq(new java.util.ArrayList())); + } else { + throw new IllegalArgumentException(node.getComputationType()+" is not supported"); + } + } + + private static CaseWhen trendlineOrNullWhenThereAreTooFewDataPoints(CatalystExpressionVisitor expressionVisitor, WindowExpression trendlineWindow, Trendline.TrendlineComputation node, CatalystPlanContext context) { + //required number of data points + expressionVisitor.visitLiteral(new Literal(node.getNumberOfDataPoints(), DataType.INTEGER), context); + Expression requiredNumberOfDataPoints = context.popNamedParseExpressions().get(); + + //count data points function + expressionVisitor.visitAggregateFunction(new AggregateFunction(BuiltinFunctionName.COUNT.name(), new Literal(1, DataType.INTEGER)), context); + Expression countDataPointsFunction = context.popNamedParseExpressions().get(); + //count data points window + WindowExpression countDataPointsWindow = new WindowExpression( + countDataPointsFunction, + trendlineWindow.windowSpec()); + + expressionVisitor.visitLiteral(new Literal(null, DataType.NULL), context); + Expression nullLiteral = context.popNamedParseExpressions().get(); + Tuple2 nullWhenNumberOfDataPointsLessThenRequired = new Tuple2<>( + new LessThan(countDataPointsWindow, requiredNumberOfDataPoints), + nullLiteral + ); + return new CaseWhen(seq(nullWhenNumberOfDataPointsLessThenRequired), Option.apply(trendlineWindow)); + } +} diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTrendlineCommandTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTrendlineCommandTranslatorTestSuite.scala new file mode 100644 index 000000000..d22750ee0 --- /dev/null +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTrendlineCommandTranslatorTestSuite.scala @@ -0,0 +1,135 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.opensearch.flint.spark.ppl.PlaneUtils.plan +import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} +import org.scalatest.matchers.should.Matchers + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, CaseWhen, CurrentRow, Descending, LessThan, Literal, RowFrame, SortOrder, SpecifiedWindowFrame, WindowExpression, WindowSpecDefinition} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{Project, Sort} + +class PPLLogicalPlanTrendlineCommandTranslatorTestSuite + extends SparkFunSuite + with PlanTest + with LogicalPlanTestUtils + with Matchers { + + private val planTransformer = new CatalystQueryPlanVisitor() + private val pplParser = new PPLSyntaxParser() + + test("test trendline") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit(plan(pplParser, "source=relation | trendline sma(3, 
age)"), context) + + val table = UnresolvedRelation(Seq("relation")) + val ageField = UnresolvedAttribute("age") + val countWindow = new WindowExpression( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-2), CurrentRow))) + val smaWindow = WindowExpression( + UnresolvedFunction("AVG", Seq(ageField), isDistinct = false), + WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-2), CurrentRow))) + val caseWhen = CaseWhen(Seq((LessThan(countWindow, Literal(3)), Literal(null))), smaWindow) + val trendlineProjectList = Seq(UnresolvedStar(None), Alias(caseWhen, "age_trendline")()) + val expectedPlan = + Project(Seq(UnresolvedStar(None)), Project(trendlineProjectList, table)) + comparePlans(logPlan, expectedPlan, checkAnalysis = false) + } + + test("test trendline with sort") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, "source=relation | trendline sort age sma(3, age)"), + context) + + val table = UnresolvedRelation(Seq("relation")) + val ageField = UnresolvedAttribute("age") + val sort = Sort(Seq(SortOrder(ageField, Ascending)), global = true, table) + val countWindow = new WindowExpression( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-2), CurrentRow))) + val smaWindow = WindowExpression( + UnresolvedFunction("AVG", Seq(ageField), isDistinct = false), + WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-2), CurrentRow))) + val caseWhen = CaseWhen(Seq((LessThan(countWindow, Literal(3)), Literal(null))), smaWindow) + val trendlineProjectList = Seq(UnresolvedStar(None), Alias(caseWhen, "age_trendline")()) + val expectedPlan = + Project(Seq(UnresolvedStar(None)), Project(trendlineProjectList, sort)) + comparePlans(logPlan, expectedPlan, checkAnalysis = false) + } + + test("test trendline with sort and alias") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, "source=relation | trendline sort - age sma(3, age) as age_sma"), + context) + + val table = UnresolvedRelation(Seq("relation")) + val ageField = UnresolvedAttribute("age") + val sort = Sort(Seq(SortOrder(ageField, Descending)), global = true, table) + val countWindow = new WindowExpression( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-2), CurrentRow))) + val smaWindow = WindowExpression( + UnresolvedFunction("AVG", Seq(ageField), isDistinct = false), + WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-2), CurrentRow))) + val caseWhen = CaseWhen(Seq((LessThan(countWindow, Literal(3)), Literal(null))), smaWindow) + val trendlineProjectList = Seq(UnresolvedStar(None), Alias(caseWhen, "age_sma")()) + val expectedPlan = + Project(Seq(UnresolvedStar(None)), Project(trendlineProjectList, sort)) + comparePlans(logPlan, expectedPlan, checkAnalysis = false) + } + + test("test trendline with multiple trendline sma commands") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source=relation | trendline sort + age sma(2, age) as two_points_sma sma(3, age) | fields name, age, two_points_sma, age_trendline"), + context) + + val table = UnresolvedRelation(Seq("relation")) + val nameField = UnresolvedAttribute("name") + val ageField = 
UnresolvedAttribute("age") + val ageTwoPointsSmaField = UnresolvedAttribute("two_points_sma") + val ageTrendlineField = UnresolvedAttribute("age_trendline") + val sort = Sort(Seq(SortOrder(ageField, Ascending)), global = true, table) + val twoPointsCountWindow = new WindowExpression( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-1), CurrentRow))) + val twoPointsSmaWindow = WindowExpression( + UnresolvedFunction("AVG", Seq(ageField), isDistinct = false), + WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-1), CurrentRow))) + val threePointsCountWindow = new WindowExpression( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-2), CurrentRow))) + val threePointsSmaWindow = WindowExpression( + UnresolvedFunction("AVG", Seq(ageField), isDistinct = false), + WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-2), CurrentRow))) + val twoPointsCaseWhen = CaseWhen( + Seq((LessThan(twoPointsCountWindow, Literal(2)), Literal(null))), + twoPointsSmaWindow) + val threePointsCaseWhen = CaseWhen( + Seq((LessThan(threePointsCountWindow, Literal(3)), Literal(null))), + threePointsSmaWindow) + val trendlineProjectList = Seq( + UnresolvedStar(None), + Alias(twoPointsCaseWhen, "two_points_sma")(), + Alias(threePointsCaseWhen, "age_trendline")()) + val expectedPlan = Project( + Seq(nameField, ageField, ageTwoPointsSmaField, ageTrendlineField), + Project(trendlineProjectList, sort)) + comparePlans(logPlan, expectedPlan, checkAnalysis = false) + } +}