diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTrendlineCommandTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTrendlineCommandTranslatorTestSuite.scala index 00dbbc574..d49672ce4 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTrendlineCommandTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTrendlineCommandTranslatorTestSuite.scala @@ -135,7 +135,7 @@ class PPLLogicalPlanTrendlineCommandTranslatorTestSuite comparePlans(logPlan, expectedPlan, checkAnalysis = false) } - test("wma - with sort") { + test("WMA - with sort") { val context = new CatalystPlanContext val logPlan = planTransformer.visit(plan(pplParser, "source=relation | trendline sort age wma(3, age)"), context) @@ -143,8 +143,7 @@ class PPLLogicalPlanTrendlineCommandTranslatorTestSuite val dividend = Add(Add(getNthValueAggregation("age", "age", 1, -2), getNthValueAggregation("age", "age", 2, -2)), getNthValueAggregation("age", "age", 3, -2)) - val divisor = Literal(6) - val wmaExpression = Divide(dividend, divisor) + val wmaExpression = Divide(dividend, Literal(6)) val trendlineProjectList = Seq(UnresolvedStar(None), Alias(wmaExpression, "age_trendline")()) val sortedTable = Sort( Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), global = true, UnresolvedRelation(Seq("relation"))) @@ -162,7 +161,7 @@ class PPLLogicalPlanTrendlineCommandTranslatorTestSuite comparePlans(logPlan, expectedPlan, checkAnalysis = false) } - test("wma - with sort and alias") { + test("WMA - with sort and alias") { val context = new CatalystPlanContext val logPlan = planTransformer.visit(plan(pplParser, "source=relation | trendline sort age wma(3, age) as TEST_CUSTOM_COLUMN"), context) @@ -170,10 +169,8 @@ class PPLLogicalPlanTrendlineCommandTranslatorTestSuite val dividend = Add(Add(getNthValueAggregation("age", "age", 1, -2), getNthValueAggregation("age", "age", 2, -2)), getNthValueAggregation("age", "age", 3, -2)) - val divisor = Literal(6) - val wmaExpression = Divide(dividend, divisor) + val wmaExpression = Divide(dividend, Literal(6)) val trendlineProjectList = Seq(UnresolvedStar(None), Alias(wmaExpression, "TEST_CUSTOM_COLUMN")()) - val sortedTable = Sort( Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), global = true, UnresolvedRelation(Seq("relation"))) @@ -193,7 +190,7 @@ class PPLLogicalPlanTrendlineCommandTranslatorTestSuite } - test("wma - multiple trendline commands") { + test("WMA - multiple trendline commands") { val context = new CatalystPlanContext val logPlan = planTransformer.visit(plan(pplParser, "source=relation | trendline sort age wma(2, age) as two_points_wma wma(3, age) as three_points_wma"), context) @@ -205,15 +202,11 @@ class PPLLogicalPlanTrendlineCommandTranslatorTestSuite val dividend = Add(Add(getNthValueAggregation("age", "age", 1, -2), getNthValueAggregation("age", "age", 2, -2)), getNthValueAggregation("age", "age", 3, -2)) - val divisor = Literal(6) - val threePointsExpression = Divide(dividend, divisor) - + val threePointsExpression = Divide(dividend, Literal(6)) val trendlineProjectList = Seq(UnresolvedStar(None), Alias(twoPointsExpression, "two_points_wma")(), Alias(threePointsExpression, "three_points_wma")()) - val sortedTable = Sort( Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), global = true, UnresolvedRelation(Seq("relation"))) - /** * Expected logical plan: * 'Project [*] @@ -233,8 +226,6 @@ class PPLLogicalPlanTrendlineCommandTranslatorTestSuite } - - private def getNthValueAggregation(dataField: String, sortField: String, lookBackPos: Int, lookBackRange: Int): Expression = { Multiply( WindowExpression(