Skip to content

Commit

Permalink
Update test cases
Browse files Browse the repository at this point in the history
Signed-off-by: Andy Kwok <[email protected]>
  • Loading branch information
andy-k-improving committed Nov 7, 2024
1 parent ec05bf6 commit d376ec5
Showing 1 changed file with 11 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,7 @@ class FlintSparkPPLTrendlineITSuite
| source = $testTable | trendline sort + age wma(3, age)
| """.stripMargin)

// Compare the headers
assert(
frame.columns.sameElements(
Array("name", "age", "state", "country", "year", "month", "age_trendline")))
Expand All @@ -268,16 +269,12 @@ class FlintSparkPPLTrendlineITSuite
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0))
assert(results.sorted.sameElements(expectedResults.sorted))

// Compare the logical plans
val logicalPlan: LogicalPlan = frame.queryExecution.logical
// // scalastyle:off
// println(logicalPlan.toString())
// // scalastyle:on println

val dividend = Add(Add(getNthValueAggregation("age", "age", 1, -2),
getNthValueAggregation("age", "age", 2, -2)),
getNthValueAggregation("age", "age", 3, -2))
val divisor = Literal(6)
val wmaExpression = Divide(dividend, divisor)
val wmaExpression = Divide(dividend, Literal(6))
val trendlineProjectList = Seq(UnresolvedStar(None), Alias(wmaExpression, "age_trendline")())
val unresolvedRelation = UnresolvedRelation(testTable.split("\\.").toSeq)
val sortedTable = Sort(
Expand All @@ -301,6 +298,7 @@ class FlintSparkPPLTrendlineITSuite
| source = $testTable | trendline sort + age wma(3, age) as trendline_alias
| """.stripMargin)

// Compare the headers
assert(
frame.columns.sameElements(
Array("name", "age", "state", "country", "year", "month", "trendline_alias")))
Expand All @@ -317,25 +315,18 @@ class FlintSparkPPLTrendlineITSuite
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0))
assert(results.sorted.sameElements(expectedResults.sorted))

// Compare the logical plans
val logicalPlan: LogicalPlan = frame.queryExecution.logical


val dividend = Add(Add(getNthValueAggregation("age", "age", 1, -2),
getNthValueAggregation("age", "age", 2, -2)),
getNthValueAggregation("age", "age", 3, -2))
val divisor = Literal(6)
val wmaExpression = Divide(dividend, divisor)
val wmaExpression = Divide(dividend, Literal(6))
val trendlineProjectList = Seq(UnresolvedStar(None), Alias(wmaExpression, "trendline_alias")())
val unresolvedRelation = UnresolvedRelation(testTable.split("\\.").toSeq)
val sortedTable = Sort(
Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), global = true, unresolvedRelation)
val expectedPlan = Project(Seq(UnresolvedStar(None)), Project(trendlineProjectList, sortedTable))

// scalastyle:off
println(logicalPlan.toString())
println(expectedPlan.toString())
// scalastyle:on println

/**
* 'Project [*]
* +- 'Project [*, (((
Expand All @@ -353,6 +344,7 @@ class FlintSparkPPLTrendlineITSuite
| source = $testTable | trendline sort + age wma(2, age) as two_points_wma wma(3, age) as three_points_wma
| """.stripMargin)

// Compare the headers
assert(
frame.columns.sameElements(
Array("name", "age", "state", "country", "year", "month", "two_points_wma", "three_points_wma")))
Expand All @@ -369,11 +361,8 @@ class FlintSparkPPLTrendlineITSuite
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0))
assert(results.sorted.sameElements(expectedResults.sorted))

// TBC The logical plan
// Compare the logical plans
val logicalPlan: LogicalPlan = frame.queryExecution.logical
// scalastyle:off
println(logicalPlan.toString())
// scalastyle:on println

val dividendTwo = Add(getNthValueAggregation("age", "age", 1, -1),
getNthValueAggregation("age", "age", 2, -1))
Expand Down Expand Up @@ -405,11 +394,12 @@ class FlintSparkPPLTrendlineITSuite
comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
}

test("test rendline wma command on evaluated column") {
test("test trendline wma command on evaluated column") {
val frame = sql(s"""
| source = $testTable | eval doubled_age = age * 2 | trendline sort + age wma(2, doubled_age) as doubled_age_wma | fields name, doubled_age, doubled_age_wma
| """.stripMargin)

// Compare the headers
assert(
frame.columns.sameElements(
Array("name", "doubled_age", "doubled_age_wma")))
Expand All @@ -426,30 +416,22 @@ class FlintSparkPPLTrendlineITSuite
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0))
assert(results.sorted.sameElements(expectedResults.sorted))

// TBC The logical plan
// Compare the logical plans
val logicalPlan: LogicalPlan = frame.queryExecution.logical

val dividend = Add(getNthValueAggregation("doubled_age", "age", 1, -1),
getNthValueAggregation("doubled_age", "age", 2, -1))
val wmaExpression = Divide(dividend, Literal(3))
val trendlineProjectList = Seq(UnresolvedStar(None),
Alias(wmaExpression, "doubled_age_wma")())

val unresolvedRelation = UnresolvedRelation(testTable.split("\\.").toSeq)


val doubledAged = Alias(UnresolvedFunction(seq("*"), seq(UnresolvedAttribute("age"), Literal(2)), isDistinct = false) , "doubled_age")()
val doubleAgeProject = Project(seq(UnresolvedStar(None), doubledAged), unresolvedRelation)

val sortedTable = Sort(
Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), global = true,
doubleAgeProject)

val expectedPlan = Project(
Seq(UnresolvedAttribute("name"),UnresolvedAttribute("doubled_age"),UnresolvedAttribute("doubled_age_wma")),
Project(trendlineProjectList, sortedTable ))


/**
*
'Project ['name, 'doubled_age, 'doubled_age_wma]
Expand Down

0 comments on commit d376ec5

Please sign in to comment.