Skip to content

Commit

Permalink
Update scalafmt
Browse files Browse the repository at this point in the history
Signed-off-by: Andy Kwok <[email protected]>
  • Loading branch information
andy-k-improving committed Nov 12, 2024
1 parent 4f1d79f commit 8f4efd5
Show file tree
Hide file tree
Showing 2 changed files with 191 additions and 119 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,14 @@

package org.opensearch.flint.spark.ppl

import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq
import org.opensearch.sql.ppl.utils.SortUtils

import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
import org.apache.spark.sql.catalyst.expressions.{Add, Alias, Ascending, CaseWhen, CurrentRow, Descending, Divide, Expression, LessThan, Literal, Multiply, RowFrame, SortOrder, SpecifiedWindowFrame, WindowExpression, WindowSpecDefinition}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.streaming.StreamTest
import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq
import org.opensearch.sql.ppl.utils.SortUtils

class FlintSparkPPLTrendlineITSuite
extends QueryTest
Expand Down Expand Up @@ -271,24 +272,29 @@ class FlintSparkPPLTrendlineITSuite

// Compare the logical plans
val logicalPlan: LogicalPlan = frame.queryExecution.logical
val dividend = Add(Add(getNthValueAggregation("age", "age", 1, -2),
getNthValueAggregation("age", "age", 2, -2)),
val dividend = Add(
Add(
getNthValueAggregation("age", "age", 1, -2),
getNthValueAggregation("age", "age", 2, -2)),
getNthValueAggregation("age", "age", 3, -2))
val wmaExpression = Divide(dividend, Literal(6))
val trendlineProjectList = Seq(UnresolvedStar(None), Alias(wmaExpression, "age_trendline")())
val unresolvedRelation = UnresolvedRelation(testTable.split("\\.").toSeq)
val sortedTable = Sort(
Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), global = true, unresolvedRelation)
val expectedPlan = Project(Seq(UnresolvedStar(None)), Project(trendlineProjectList, sortedTable))
Seq(SortOrder(UnresolvedAttribute("age"), Ascending)),
global = true,
unresolvedRelation)
val expectedPlan =
Project(Seq(UnresolvedStar(None)), Project(trendlineProjectList, sortedTable))

/**
* Expected logical plan:
* 'Project [*]
* +- 'Project [*, (((
* ('nth_value('age, 1) windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 1) +
* ('nth_value('age, 2) windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 2)) +
* ('nth_value('age, 3) windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 3)) / 6) AS age_trendline#185]
* +- 'Sort ['age ASC NULLS FIRST], true
* +- 'UnresolvedRelation [spark_catalog, default, flint_ppl_test], [], false
* Expected logical plan: 'Project [*] +- 'Project [*, ((( ('nth_value('age, 1)
* windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2,
* currentrow$())) * 1) + ('nth_value('age, 2) windowspecdefinition('age ASC NULLS FIRST,
* specifiedwindowframe(RowFrame, -2, currentrow$())) * 2)) + ('nth_value('age, 3)
* windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2,
* currentrow$())) * 3)) / 6) AS age_trendline#185] +- 'Sort ['age ASC NULLS FIRST], true +-
* 'UnresolvedRelation [spark_catalog, default, flint_ppl_test], [], false
*/
comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
}
Expand Down Expand Up @@ -317,24 +323,30 @@ class FlintSparkPPLTrendlineITSuite

// Compare the logical plans
val logicalPlan: LogicalPlan = frame.queryExecution.logical
val dividend = Add(Add(getNthValueAggregation("age", "age", 1, -2),
getNthValueAggregation("age", "age", 2, -2)),
val dividend = Add(
Add(
getNthValueAggregation("age", "age", 1, -2),
getNthValueAggregation("age", "age", 2, -2)),
getNthValueAggregation("age", "age", 3, -2))
val wmaExpression = Divide(dividend, Literal(6))
val trendlineProjectList = Seq(UnresolvedStar(None), Alias(wmaExpression, "trendline_alias")())
val trendlineProjectList =
Seq(UnresolvedStar(None), Alias(wmaExpression, "trendline_alias")())
val unresolvedRelation = UnresolvedRelation(testTable.split("\\.").toSeq)
val sortedTable = Sort(
Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), global = true, unresolvedRelation)
val expectedPlan = Project(Seq(UnresolvedStar(None)), Project(trendlineProjectList, sortedTable))
Seq(SortOrder(UnresolvedAttribute("age"), Ascending)),
global = true,
unresolvedRelation)
val expectedPlan =
Project(Seq(UnresolvedStar(None)), Project(trendlineProjectList, sortedTable))

/**
* 'Project [*]
* +- 'Project [*, (((
* ('nth_value('age, 1) windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 1) +
* ('nth_value('age, 2) windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 2)) +
* ('nth_value('age, 3) windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 3)) / 6) AS trendline_alias#185]
* +- 'Sort ['age ASC NULLS FIRST], true
* +- 'UnresolvedRelation [spark_catalog, default, flint_ppl_test], [], false
* 'Project [*] +- 'Project [*, ((( ('nth_value('age, 1) windowspecdefinition('age ASC NULLS
* FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 1) + ('nth_value('age, 2)
* windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2,
* currentrow$())) * 2)) + ('nth_value('age, 3) windowspecdefinition('age ASC NULLS FIRST,
* specifiedwindowframe(RowFrame, -2, currentrow$())) * 3)) / 6) AS trendline_alias#185] +-
* 'Sort ['age ASC NULLS FIRST], true +- 'UnresolvedRelation [spark_catalog, default,
* flint_ppl_test], [], false
*/
comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
}
Expand All @@ -347,7 +359,15 @@ class FlintSparkPPLTrendlineITSuite
// Compare the headers
assert(
frame.columns.sameElements(
Array("name", "age", "state", "country", "year", "month", "two_points_wma", "three_points_wma")))
Array(
"name",
"age",
"state",
"country",
"year",
"month",
"two_points_wma",
"three_points_wma")))
// Retrieve the results
val results: Array[Row] = frame.collect()
val expectedResults: Array[Row] =
Expand All @@ -364,32 +384,43 @@ class FlintSparkPPLTrendlineITSuite
// Compare the logical plans
val logicalPlan: LogicalPlan = frame.queryExecution.logical

val dividendTwo = Add(getNthValueAggregation("age", "age", 1, -1),
val dividendTwo = Add(
getNthValueAggregation("age", "age", 1, -1),
getNthValueAggregation("age", "age", 2, -1))
val twoPointsExpression = Divide(dividendTwo, Literal(3))

val dividend = Add(Add(getNthValueAggregation("age", "age", 1, -2),
getNthValueAggregation("age", "age", 2, -2)),
val dividend = Add(
Add(
getNthValueAggregation("age", "age", 1, -2),
getNthValueAggregation("age", "age", 2, -2)),
getNthValueAggregation("age", "age", 3, -2))
val threePointsExpression = Divide(dividend, Literal(6))

val trendlineProjectList = Seq(UnresolvedStar(None), Alias(twoPointsExpression, "two_points_wma")(), Alias(threePointsExpression, "three_points_wma")())
val trendlineProjectList = Seq(
UnresolvedStar(None),
Alias(twoPointsExpression, "two_points_wma")(),
Alias(threePointsExpression, "three_points_wma")())
val unresolvedRelation = UnresolvedRelation(testTable.split("\\.").toSeq)
val sortedTable = Sort(
Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), global = true, unresolvedRelation)
val expectedPlan = Project(Seq(UnresolvedStar(None)), Project(trendlineProjectList, sortedTable))
Seq(SortOrder(UnresolvedAttribute("age"), Ascending)),
global = true,
unresolvedRelation)
val expectedPlan =
Project(Seq(UnresolvedStar(None)), Project(trendlineProjectList, sortedTable))

/**
* 'Project [*]
* +- 'Project [*, ((
* ('nth_value('age, 1) windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -1, currentrow$())) * 1) +
* ('nth_value('age, 2) windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -1, currentrow$())) * 2)) / 3) AS two_points_wma#247,
* 'Project [*] +- 'Project [*, (( ('nth_value('age, 1) windowspecdefinition('age ASC NULLS
* FIRST, specifiedwindowframe(RowFrame, -1, currentrow$())) * 1) + ('nth_value('age, 2)
* windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -1,
* currentrow$())) * 2)) / 3) AS two_points_wma#247,
*
* (((
* ('nth_value('age, 1) windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 1) +
* ('nth_value('age, 2) windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 2)) +
* ('nth_value('age, 3) windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 3)) / 6) AS three_points_wma#248]
* +- 'Sort ['age ASC NULLS FIRST], true
* +- 'UnresolvedRelation [spark_catalog, default, flint_ppl_test], [], false
* ((( ('nth_value('age, 1) windowspecdefinition('age ASC NULLS FIRST,
* specifiedwindowframe(RowFrame, -2, currentrow$())) * 1) + ('nth_value('age, 2)
* windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2,
* currentrow$())) * 2)) + ('nth_value('age, 3) windowspecdefinition('age ASC NULLS FIRST,
* specifiedwindowframe(RowFrame, -2, currentrow$())) * 3)) / 6) AS three_points_wma#248] +-
* 'Sort ['age ASC NULLS FIRST], true +- 'UnresolvedRelation [spark_catalog, default,
* flint_ppl_test], [], false
*/
comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
}
Expand All @@ -400,9 +431,7 @@ class FlintSparkPPLTrendlineITSuite
| """.stripMargin)

// Compare the headers
assert(
frame.columns.sameElements(
Array("name", "doubled_age", "doubled_age_wma")))
assert(frame.columns.sameElements(Array("name", "doubled_age", "doubled_age_wma")))
// Retrieve the results
val results: Array[Row] = frame.collect()
val expectedResults: Array[Row] =
Expand All @@ -418,44 +447,57 @@ class FlintSparkPPLTrendlineITSuite

// Compare the logical plans
val logicalPlan: LogicalPlan = frame.queryExecution.logical
val dividend = Add(getNthValueAggregation("doubled_age", "age", 1, -1),
val dividend = Add(
getNthValueAggregation("doubled_age", "age", 1, -1),
getNthValueAggregation("doubled_age", "age", 2, -1))
val wmaExpression = Divide(dividend, Literal(3))
val trendlineProjectList = Seq(UnresolvedStar(None),
Alias(wmaExpression, "doubled_age_wma")())
val trendlineProjectList =
Seq(UnresolvedStar(None), Alias(wmaExpression, "doubled_age_wma")())
val unresolvedRelation = UnresolvedRelation(testTable.split("\\.").toSeq)
val doubledAged = Alias(UnresolvedFunction(seq("*"), seq(UnresolvedAttribute("age"), Literal(2)), isDistinct = false) , "doubled_age")()
val doubledAged = Alias(
UnresolvedFunction(
seq("*"),
seq(UnresolvedAttribute("age"), Literal(2)),
isDistinct = false),
"doubled_age")()
val doubleAgeProject = Project(seq(UnresolvedStar(None), doubledAged), unresolvedRelation)
val sortedTable = Sort(
Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), global = true,
doubleAgeProject)
val sortedTable =
Sort(Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), global = true, doubleAgeProject)
val expectedPlan = Project(
Seq(UnresolvedAttribute("name"),UnresolvedAttribute("doubled_age"),UnresolvedAttribute("doubled_age_wma")),
Project(trendlineProjectList, sortedTable ))
/**
*
'Project ['name, 'doubled_age, 'doubled_age_wma]
+- 'Project [*, ((
('nth_value('doubled_age, 1) windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -1, currentrow$())) * 1) +
('nth_value('doubled_age, 2) windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -1, currentrow$())) * 2)) / 3) AS doubled_age_wma#288]
+- 'Sort ['age ASC NULLS FIRST], true
+- 'Project [*, '`*`('age, 2) AS doubled_age#287]
+- 'UnresolvedRelation [spark_catalog, default, flint_ppl_test], [], false
Seq(
UnresolvedAttribute("name"),
UnresolvedAttribute("doubled_age"),
UnresolvedAttribute("doubled_age_wma")),
Project(trendlineProjectList, sortedTable))

/**
* 'Project ['name, 'doubled_age, 'doubled_age_wma] +- 'Project [*, ((
* ('nth_value('doubled_age, 1) windowspecdefinition('age ASC NULLS FIRST,
* specifiedwindowframe(RowFrame, -1, currentrow$())) * 1) + ('nth_value('doubled_age, 2)
* windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -1,
* currentrow$())) * 2)) / 3) AS doubled_age_wma#288] +- 'Sort ['age ASC NULLS FIRST], true +-
* 'Project [*, '`*`('age, 2) AS doubled_age#287] +- 'UnresolvedRelation [spark_catalog,
* default, flint_ppl_test], [], false
*/
comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)

}

private def getNthValueAggregation(dataField: String, sortField: String, lookBackPos: Int, lookBackRange: Int): Expression = {
private def getNthValueAggregation(
dataField: String,
sortField: String,
lookBackPos: Int,
lookBackRange: Int): Expression = {
Multiply(
WindowExpression(
UnresolvedFunction("nth_value", Seq(UnresolvedAttribute(dataField), Literal(lookBackPos)), isDistinct = false),
UnresolvedFunction(
"nth_value",
Seq(UnresolvedAttribute(dataField), Literal(lookBackPos)),
isDistinct = false),
WindowSpecDefinition(
Seq(),
seq(SortUtils.sortOrder(UnresolvedAttribute(sortField), true)),
SpecifiedWindowFrame(RowFrame, Literal(lookBackRange), CurrentRow)
)),
SpecifiedWindowFrame(RowFrame, Literal(lookBackRange), CurrentRow))),
Literal(lookBackPos))
}

Expand Down
Loading

0 comments on commit 8f4efd5

Please sign in to comment.