diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTrendlineITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTrendlineITSuite.scala index 7c10a6dd0..589cad33b 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTrendlineITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTrendlineITSuite.scala @@ -5,13 +5,14 @@ package org.opensearch.flint.spark.ppl +import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq +import org.opensearch.sql.ppl.utils.SortUtils + import org.apache.spark.sql.{QueryTest, Row} import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} import org.apache.spark.sql.catalyst.expressions.{Add, Alias, Ascending, CaseWhen, CurrentRow, Descending, Divide, Expression, LessThan, Literal, Multiply, RowFrame, SortOrder, SpecifiedWindowFrame, WindowExpression, WindowSpecDefinition} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.streaming.StreamTest -import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq -import org.opensearch.sql.ppl.utils.SortUtils class FlintSparkPPLTrendlineITSuite extends QueryTest @@ -271,24 +272,29 @@ class FlintSparkPPLTrendlineITSuite // Compare the logical plans val logicalPlan: LogicalPlan = frame.queryExecution.logical - val dividend = Add(Add(getNthValueAggregation("age", "age", 1, -2), - getNthValueAggregation("age", "age", 2, -2)), + val dividend = Add( + Add( + getNthValueAggregation("age", "age", 1, -2), + getNthValueAggregation("age", "age", 2, -2)), getNthValueAggregation("age", "age", 3, -2)) val wmaExpression = Divide(dividend, Literal(6)) val trendlineProjectList = Seq(UnresolvedStar(None), Alias(wmaExpression, "age_trendline")()) val unresolvedRelation = UnresolvedRelation(testTable.split("\\.").toSeq) val sortedTable = Sort( - Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), global = true, unresolvedRelation) - val expectedPlan = Project(Seq(UnresolvedStar(None)), Project(trendlineProjectList, sortedTable)) + Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), + global = true, + unresolvedRelation) + val expectedPlan = + Project(Seq(UnresolvedStar(None)), Project(trendlineProjectList, sortedTable)) + /** - * Expected logical plan: - * 'Project [*] - * +- 'Project [*, ((( - * ('nth_value('age, 1) windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 1) + - * ('nth_value('age, 2) windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 2)) + - * ('nth_value('age, 3) windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 3)) / 6) AS age_trendline#185] - * +- 'Sort ['age ASC NULLS FIRST], true - * +- 'UnresolvedRelation [spark_catalog, default, flint_ppl_test], [], false + * Expected logical plan: 'Project [*] +- 'Project [*, ((( ('nth_value('age, 1) + * windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, + * currentrow$())) * 1) + ('nth_value('age, 2) windowspecdefinition('age ASC NULLS FIRST, + * specifiedwindowframe(RowFrame, -2, currentrow$())) * 2)) + ('nth_value('age, 3) + * windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, + * currentrow$())) * 3)) / 6) AS age_trendline#185] +- 'Sort ['age ASC NULLS FIRST], true +- + * 'UnresolvedRelation [spark_catalog, default, flint_ppl_test], [], false */ comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) } @@ -317,24 +323,30 @@ class FlintSparkPPLTrendlineITSuite // Compare the logical plans val logicalPlan: LogicalPlan = frame.queryExecution.logical - val dividend = Add(Add(getNthValueAggregation("age", "age", 1, -2), - getNthValueAggregation("age", "age", 2, -2)), + val dividend = Add( + Add( + getNthValueAggregation("age", "age", 1, -2), + getNthValueAggregation("age", "age", 2, -2)), getNthValueAggregation("age", "age", 3, -2)) val wmaExpression = Divide(dividend, Literal(6)) - val trendlineProjectList = Seq(UnresolvedStar(None), Alias(wmaExpression, "trendline_alias")()) + val trendlineProjectList = + Seq(UnresolvedStar(None), Alias(wmaExpression, "trendline_alias")()) val unresolvedRelation = UnresolvedRelation(testTable.split("\\.").toSeq) val sortedTable = Sort( - Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), global = true, unresolvedRelation) - val expectedPlan = Project(Seq(UnresolvedStar(None)), Project(trendlineProjectList, sortedTable)) + Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), + global = true, + unresolvedRelation) + val expectedPlan = + Project(Seq(UnresolvedStar(None)), Project(trendlineProjectList, sortedTable)) /** - * 'Project [*] - * +- 'Project [*, ((( - * ('nth_value('age, 1) windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 1) + - * ('nth_value('age, 2) windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 2)) + - * ('nth_value('age, 3) windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 3)) / 6) AS trendline_alias#185] - * +- 'Sort ['age ASC NULLS FIRST], true - * +- 'UnresolvedRelation [spark_catalog, default, flint_ppl_test], [], false + * 'Project [*] +- 'Project [*, ((( ('nth_value('age, 1) windowspecdefinition('age ASC NULLS + * FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 1) + ('nth_value('age, 2) + * windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, + * currentrow$())) * 2)) + ('nth_value('age, 3) windowspecdefinition('age ASC NULLS FIRST, + * specifiedwindowframe(RowFrame, -2, currentrow$())) * 3)) / 6) AS trendline_alias#185] +- + * 'Sort ['age ASC NULLS FIRST], true +- 'UnresolvedRelation [spark_catalog, default, + * flint_ppl_test], [], false */ comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) } @@ -347,7 +359,15 @@ class FlintSparkPPLTrendlineITSuite // Compare the headers assert( frame.columns.sameElements( - Array("name", "age", "state", "country", "year", "month", "two_points_wma", "three_points_wma"))) + Array( + "name", + "age", + "state", + "country", + "year", + "month", + "two_points_wma", + "three_points_wma"))) // Retrieve the results val results: Array[Row] = frame.collect() val expectedResults: Array[Row] = @@ -364,32 +384,43 @@ class FlintSparkPPLTrendlineITSuite // Compare the logical plans val logicalPlan: LogicalPlan = frame.queryExecution.logical - val dividendTwo = Add(getNthValueAggregation("age", "age", 1, -1), + val dividendTwo = Add( + getNthValueAggregation("age", "age", 1, -1), getNthValueAggregation("age", "age", 2, -1)) val twoPointsExpression = Divide(dividendTwo, Literal(3)) - val dividend = Add(Add(getNthValueAggregation("age", "age", 1, -2), - getNthValueAggregation("age", "age", 2, -2)), + val dividend = Add( + Add( + getNthValueAggregation("age", "age", 1, -2), + getNthValueAggregation("age", "age", 2, -2)), getNthValueAggregation("age", "age", 3, -2)) val threePointsExpression = Divide(dividend, Literal(6)) - val trendlineProjectList = Seq(UnresolvedStar(None), Alias(twoPointsExpression, "two_points_wma")(), Alias(threePointsExpression, "three_points_wma")()) + val trendlineProjectList = Seq( + UnresolvedStar(None), + Alias(twoPointsExpression, "two_points_wma")(), + Alias(threePointsExpression, "three_points_wma")()) val unresolvedRelation = UnresolvedRelation(testTable.split("\\.").toSeq) val sortedTable = Sort( - Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), global = true, unresolvedRelation) - val expectedPlan = Project(Seq(UnresolvedStar(None)), Project(trendlineProjectList, sortedTable)) + Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), + global = true, + unresolvedRelation) + val expectedPlan = + Project(Seq(UnresolvedStar(None)), Project(trendlineProjectList, sortedTable)) + /** - * 'Project [*] - * +- 'Project [*, (( - * ('nth_value('age, 1) windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -1, currentrow$())) * 1) + - * ('nth_value('age, 2) windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -1, currentrow$())) * 2)) / 3) AS two_points_wma#247, + * 'Project [*] +- 'Project [*, (( ('nth_value('age, 1) windowspecdefinition('age ASC NULLS + * FIRST, specifiedwindowframe(RowFrame, -1, currentrow$())) * 1) + ('nth_value('age, 2) + * windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -1, + * currentrow$())) * 2)) / 3) AS two_points_wma#247, * - * ((( - * ('nth_value('age, 1) windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 1) + - * ('nth_value('age, 2) windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 2)) + - * ('nth_value('age, 3) windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 3)) / 6) AS three_points_wma#248] - * +- 'Sort ['age ASC NULLS FIRST], true - * +- 'UnresolvedRelation [spark_catalog, default, flint_ppl_test], [], false + * ((( ('nth_value('age, 1) windowspecdefinition('age ASC NULLS FIRST, + * specifiedwindowframe(RowFrame, -2, currentrow$())) * 1) + ('nth_value('age, 2) + * windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, + * currentrow$())) * 2)) + ('nth_value('age, 3) windowspecdefinition('age ASC NULLS FIRST, + * specifiedwindowframe(RowFrame, -2, currentrow$())) * 3)) / 6) AS three_points_wma#248] +- + * 'Sort ['age ASC NULLS FIRST], true +- 'UnresolvedRelation [spark_catalog, default, + * flint_ppl_test], [], false */ comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) } @@ -400,9 +431,7 @@ class FlintSparkPPLTrendlineITSuite | """.stripMargin) // Compare the headers - assert( - frame.columns.sameElements( - Array("name", "doubled_age", "doubled_age_wma"))) + assert(frame.columns.sameElements(Array("name", "doubled_age", "doubled_age_wma"))) // Retrieve the results val results: Array[Row] = frame.collect() val expectedResults: Array[Row] = @@ -418,44 +447,57 @@ class FlintSparkPPLTrendlineITSuite // Compare the logical plans val logicalPlan: LogicalPlan = frame.queryExecution.logical - val dividend = Add(getNthValueAggregation("doubled_age", "age", 1, -1), + val dividend = Add( + getNthValueAggregation("doubled_age", "age", 1, -1), getNthValueAggregation("doubled_age", "age", 2, -1)) val wmaExpression = Divide(dividend, Literal(3)) - val trendlineProjectList = Seq(UnresolvedStar(None), - Alias(wmaExpression, "doubled_age_wma")()) + val trendlineProjectList = + Seq(UnresolvedStar(None), Alias(wmaExpression, "doubled_age_wma")()) val unresolvedRelation = UnresolvedRelation(testTable.split("\\.").toSeq) - val doubledAged = Alias(UnresolvedFunction(seq("*"), seq(UnresolvedAttribute("age"), Literal(2)), isDistinct = false) , "doubled_age")() + val doubledAged = Alias( + UnresolvedFunction( + seq("*"), + seq(UnresolvedAttribute("age"), Literal(2)), + isDistinct = false), + "doubled_age")() val doubleAgeProject = Project(seq(UnresolvedStar(None), doubledAged), unresolvedRelation) - val sortedTable = Sort( - Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), global = true, - doubleAgeProject) + val sortedTable = + Sort(Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), global = true, doubleAgeProject) val expectedPlan = Project( - Seq(UnresolvedAttribute("name"),UnresolvedAttribute("doubled_age"),UnresolvedAttribute("doubled_age_wma")), - Project(trendlineProjectList, sortedTable )) - /** - * - 'Project ['name, 'doubled_age, 'doubled_age_wma] - +- 'Project [*, (( - ('nth_value('doubled_age, 1) windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -1, currentrow$())) * 1) + - ('nth_value('doubled_age, 2) windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -1, currentrow$())) * 2)) / 3) AS doubled_age_wma#288] - +- 'Sort ['age ASC NULLS FIRST], true - +- 'Project [*, '`*`('age, 2) AS doubled_age#287] - +- 'UnresolvedRelation [spark_catalog, default, flint_ppl_test], [], false + Seq( + UnresolvedAttribute("name"), + UnresolvedAttribute("doubled_age"), + UnresolvedAttribute("doubled_age_wma")), + Project(trendlineProjectList, sortedTable)) + /** + * 'Project ['name, 'doubled_age, 'doubled_age_wma] +- 'Project [*, (( + * ('nth_value('doubled_age, 1) windowspecdefinition('age ASC NULLS FIRST, + * specifiedwindowframe(RowFrame, -1, currentrow$())) * 1) + ('nth_value('doubled_age, 2) + * windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -1, + * currentrow$())) * 2)) / 3) AS doubled_age_wma#288] +- 'Sort ['age ASC NULLS FIRST], true +- + * 'Project [*, '`*`('age, 2) AS doubled_age#287] +- 'UnresolvedRelation [spark_catalog, + * default, flint_ppl_test], [], false */ comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) } - private def getNthValueAggregation(dataField: String, sortField: String, lookBackPos: Int, lookBackRange: Int): Expression = { + private def getNthValueAggregation( + dataField: String, + sortField: String, + lookBackPos: Int, + lookBackRange: Int): Expression = { Multiply( WindowExpression( - UnresolvedFunction("nth_value", Seq(UnresolvedAttribute(dataField), Literal(lookBackPos)), isDistinct = false), + UnresolvedFunction( + "nth_value", + Seq(UnresolvedAttribute(dataField), Literal(lookBackPos)), + isDistinct = false), WindowSpecDefinition( Seq(), seq(SortUtils.sortOrder(UnresolvedAttribute(sortField), true)), - SpecifiedWindowFrame(RowFrame, Literal(lookBackRange), CurrentRow) - )), + SpecifiedWindowFrame(RowFrame, Literal(lookBackRange), CurrentRow))), Literal(lookBackPos)) } diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTrendlineCommandTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTrendlineCommandTranslatorTestSuite.scala index d49672ce4..baf472a08 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTrendlineCommandTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTrendlineCommandTranslatorTestSuite.scala @@ -138,25 +138,32 @@ class PPLLogicalPlanTrendlineCommandTranslatorTestSuite test("WMA - with sort") { val context = new CatalystPlanContext val logPlan = - planTransformer.visit(plan(pplParser, "source=relation | trendline sort age wma(3, age)"), context) + planTransformer.visit( + plan(pplParser, "source=relation | trendline sort age wma(3, age)"), + context) - val dividend = Add(Add(getNthValueAggregation("age", "age", 1, -2), - getNthValueAggregation("age", "age", 2, -2)), - getNthValueAggregation("age", "age", 3, -2)) + val dividend = Add( + Add( + getNthValueAggregation("age", "age", 1, -2), + getNthValueAggregation("age", "age", 2, -2)), + getNthValueAggregation("age", "age", 3, -2)) val wmaExpression = Divide(dividend, Literal(6)) val trendlineProjectList = Seq(UnresolvedStar(None), Alias(wmaExpression, "age_trendline")()) val sortedTable = Sort( - Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), global = true, UnresolvedRelation(Seq("relation"))) - val expectedPlan = Project(Seq(UnresolvedStar(None)), Project(trendlineProjectList, sortedTable)) + Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), + global = true, + UnresolvedRelation(Seq("relation"))) + val expectedPlan = + Project(Seq(UnresolvedStar(None)), Project(trendlineProjectList, sortedTable)) + /** - * Expected logical plan: - * 'Project [*] - * !+- 'Project [*, ((( - * ('nth_value('age, 1) windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 1) + - * ('nth_value('age, 2) windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 2)) + - * ('nth_value('age, 3) windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 3)) / 6) AS age_trendline#0] - * ! +- 'Sort ['age ASC NULLS FIRST], true - * ! +- 'UnresolvedRelation [relation], [], false + * Expected logical plan: 'Project [*] !+- 'Project [*, ((( ('nth_value('age, 1) + * windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, + * currentrow$())) * 1) + ('nth_value('age, 2) windowspecdefinition('age ASC NULLS FIRST, + * specifiedwindowframe(RowFrame, -2, currentrow$())) * 2)) + ('nth_value('age, 3) + * windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, + * currentrow$())) * 3)) / 6) AS age_trendline#0] ! +- 'Sort ['age ASC NULLS FIRST], true ! +- + * 'UnresolvedRelation [relation], [], false */ comparePlans(logPlan, expectedPlan, checkAnalysis = false) } @@ -164,28 +171,34 @@ class PPLLogicalPlanTrendlineCommandTranslatorTestSuite test("WMA - with sort and alias") { val context = new CatalystPlanContext val logPlan = - planTransformer.visit(plan(pplParser, "source=relation | trendline sort age wma(3, age) as TEST_CUSTOM_COLUMN"), context) + planTransformer.visit( + plan(pplParser, "source=relation | trendline sort age wma(3, age) as TEST_CUSTOM_COLUMN"), + context) - val dividend = Add(Add(getNthValueAggregation("age", "age", 1, -2), - getNthValueAggregation("age", "age", 2, -2)), + val dividend = Add( + Add( + getNthValueAggregation("age", "age", 1, -2), + getNthValueAggregation("age", "age", 2, -2)), getNthValueAggregation("age", "age", 3, -2)) val wmaExpression = Divide(dividend, Literal(6)) - val trendlineProjectList = Seq(UnresolvedStar(None), Alias(wmaExpression, "TEST_CUSTOM_COLUMN")()) + val trendlineProjectList = + Seq(UnresolvedStar(None), Alias(wmaExpression, "TEST_CUSTOM_COLUMN")()) val sortedTable = Sort( - Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), global = true, + Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), + global = true, UnresolvedRelation(Seq("relation"))) /** - * Expected logical plan: - * 'Project [*] - * !+- 'Project [*, ((( - * ('nth_value('age, 1) windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 1) + - * ('nth_value('age, 2) windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 2)) + - * ('nth_value('age, 3) windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 3)) / 6) AS TEST_CUSTOM_COLUMN#0] - * ! +- 'Sort ['age ASC NULLS FIRST], true - * ! +- 'UnresolvedRelation [relation], [], false + * Expected logical plan: 'Project [*] !+- 'Project [*, ((( ('nth_value('age, 1) + * windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, + * currentrow$())) * 1) + ('nth_value('age, 2) windowspecdefinition('age ASC NULLS FIRST, + * specifiedwindowframe(RowFrame, -2, currentrow$())) * 2)) + ('nth_value('age, 3) + * windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, + * currentrow$())) * 3)) / 6) AS TEST_CUSTOM_COLUMN#0] ! +- 'Sort ['age ASC NULLS FIRST], true + * ! +- 'UnresolvedRelation [relation], [], false */ - val expectedPlan = Project(Seq(UnresolvedStar(None)), Project(trendlineProjectList, sortedTable)) + val expectedPlan = + Project(Seq(UnresolvedStar(None)), Project(trendlineProjectList, sortedTable)) comparePlans(logPlan, expectedPlan, checkAnalysis = false) } @@ -193,50 +206,67 @@ class PPLLogicalPlanTrendlineCommandTranslatorTestSuite test("WMA - multiple trendline commands") { val context = new CatalystPlanContext val logPlan = - planTransformer.visit(plan(pplParser, "source=relation | trendline sort age wma(2, age) as two_points_wma wma(3, age) as three_points_wma"), context) + planTransformer.visit( + plan( + pplParser, + "source=relation | trendline sort age wma(2, age) as two_points_wma wma(3, age) as three_points_wma"), + context) - val dividendTwo = Add(getNthValueAggregation("age", "age", 1, -1), + val dividendTwo = Add( + getNthValueAggregation("age", "age", 1, -1), getNthValueAggregation("age", "age", 2, -1)) val twoPointsExpression = Divide(dividendTwo, Literal(3)) - val dividend = Add(Add(getNthValueAggregation("age", "age", 1, -2), - getNthValueAggregation("age", "age", 2, -2)), + val dividend = Add( + Add( + getNthValueAggregation("age", "age", 1, -2), + getNthValueAggregation("age", "age", 2, -2)), getNthValueAggregation("age", "age", 3, -2)) val threePointsExpression = Divide(dividend, Literal(6)) - val trendlineProjectList = Seq(UnresolvedStar(None), Alias(twoPointsExpression, "two_points_wma")(), Alias(threePointsExpression, "three_points_wma")()) + val trendlineProjectList = Seq( + UnresolvedStar(None), + Alias(twoPointsExpression, "two_points_wma")(), + Alias(threePointsExpression, "three_points_wma")()) val sortedTable = Sort( - Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), global = true, + Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), + global = true, UnresolvedRelation(Seq("relation"))) + /** - * Expected logical plan: - * 'Project [*] - * +- 'Project [*, (( - * ('nth_value('age, 1) windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -1, currentrow$())) * 1) + - * ('nth_value('age, 2) windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -1, currentrow$())) * 2)) / 3) AS two_points_wma#0, + * Expected logical plan: 'Project [*] +- 'Project [*, (( ('nth_value('age, 1) + * windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -1, + * currentrow$())) * 1) + ('nth_value('age, 2) windowspecdefinition('age ASC NULLS FIRST, + * specifiedwindowframe(RowFrame, -1, currentrow$())) * 2)) / 3) AS two_points_wma#0, * - * ((( - * ('nth_value('age, 1) windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 1) + - * ('nth_value('age, 2) windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 2)) + - * ('nth_value('age, 3) windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 3)) / 6) AS three_points_wma#1] - * +- 'Sort ['age ASC NULLS FIRST], true - * +- 'UnresolvedRelation [relation], [], false + * ((( ('nth_value('age, 1) windowspecdefinition('age ASC NULLS FIRST, + * specifiedwindowframe(RowFrame, -2, currentrow$())) * 1) + ('nth_value('age, 2) + * windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, + * currentrow$())) * 2)) + ('nth_value('age, 3) windowspecdefinition('age ASC NULLS FIRST, + * specifiedwindowframe(RowFrame, -2, currentrow$())) * 3)) / 6) AS three_points_wma#1] +- + * 'Sort ['age ASC NULLS FIRST], true +- 'UnresolvedRelation [relation], [], false */ - val expectedPlan = Project(Seq(UnresolvedStar(None)), Project(trendlineProjectList, sortedTable)) + val expectedPlan = + Project(Seq(UnresolvedStar(None)), Project(trendlineProjectList, sortedTable)) comparePlans(logPlan, expectedPlan, checkAnalysis = false) } - private def getNthValueAggregation(dataField: String, sortField: String, lookBackPos: Int, lookBackRange: Int): Expression = { + private def getNthValueAggregation( + dataField: String, + sortField: String, + lookBackPos: Int, + lookBackRange: Int): Expression = { Multiply( WindowExpression( - UnresolvedFunction("nth_value", Seq(UnresolvedAttribute(dataField), Literal(lookBackPos)), isDistinct = false), + UnresolvedFunction( + "nth_value", + Seq(UnresolvedAttribute(dataField), Literal(lookBackPos)), + isDistinct = false), WindowSpecDefinition( Seq(), seq(SortUtils.sortOrder(UnresolvedAttribute(sortField), true)), - SpecifiedWindowFrame(RowFrame, Literal(lookBackRange), CurrentRow) - )), + SpecifiedWindowFrame(RowFrame, Literal(lookBackRange), CurrentRow))), Literal(lookBackPos)) } - }