Commit 96ed7a5

wip

Signed-off-by: Kacper Trochimiak <[email protected]>
kt-eliatra committed Oct 11, 2024
1 parent 080844b commit 96ed7a5
Showing 3 changed files with 157 additions and 21 deletions.
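The new integration tests below pin down the PPL `trendline sma(n, field)` semantics as an AVG over a row-based window covering the current row and the n-1 preceding rows, applied after the `sort`. As a point of reference, here is a minimal, self-contained Spark sketch of that computation (not part of this commit; the session settings, DataFrame and column names are illustrative only) that reproduces the values asserted in the first test:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{avg, col, desc}

// Illustrative data shaped like the suite's flint_ppl_test rows (name and age only).
val spark = SparkSession.builder().master("local[1]").appName("sma-sketch").getOrCreate()
import spark.implicits._
val people = Seq(("Jake", 70), ("Hello", 30), ("John", 25), ("Jane", 20)).toDF("name", "age")

// Equivalent of `sort - age | trendline sma(2, age) as age_sma`:
// AVG(age) over ROWS BETWEEN 1 PRECEDING AND CURRENT ROW, ordered by age descending.
val twoPointWindow = Window.orderBy(desc("age")).rowsBetween(-1, Window.currentRow)
people.withColumn("age_sma", avg(col("age")).over(twoPointWindow)).show()
// age_sma: 70.0, 50.0, 27.5, 22.5 — the values asserted in the first test below.
```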
@@ -5,10 +5,11 @@

package org.opensearch.flint.spark.ppl

-import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, CurrentRow, Descending, Literal, RowFrame, SortOrder, SpecifiedWindowFrame, WindowExpression, WindowSpecDefinition}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.streaming.StreamTest
-import org.apache.spark.sql.{QueryTest, Row}

class FlintSparkPPLTrendlineITSuite
extends QueryTest
@@ -35,28 +36,152 @@ class FlintSparkPPLTrendlineITSuite
}
}

test("trendline sma") {
test("test trendline sma command without fields command") {
val frame = sql(s"""
-| source = $testTable | trendline sma(2, age) as first_age_sma sma(3, age) as second_age_sma | fields name, first_age_sma, second_age_sma
| source = $testTable | sort - age | trendline sma(2, age) as age_sma
| """.stripMargin)

assert(
frame.columns.sameElements(
Array("name", "age", "state", "country", "year", "month", "age_sma")))
// Retrieve the results
val results: Array[Row] = frame.collect()
-// Define the expected results
-val expectedResults: Array[Row] = Array()

-// Convert actual results to a Set for quick lookup
-val resultsSet: Set[Row] = results.toSet
-// Check that each expected row is present in the actual results
-expectedResults.foreach { expectedRow =>
-assert(resultsSet.contains(expectedRow), s"Expected row $expectedRow not found in results")
-}
val expectedResults: Array[Row] =
Array(
Row("Jake", 70, "California", "USA", 2023, 4, 70.0),
Row("Hello", 30, "New York", "USA", 2023, 4, 50.0),
Row("John", 25, "Ontario", "Canada", 2023, 4, 27.5),
Row("Jane", 20, "Quebec", "Canada", 2023, 4, 22.5))
// Compare the results
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0))
assert(results.sorted.sameElements(expectedResults.sorted))

// Retrieve the logical plan
val logicalPlan: LogicalPlan = frame.queryExecution.logical
val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))
val ageField = UnresolvedAttribute("age")
val sort = Sort(Seq(SortOrder(ageField, Descending)), global = true, table)
val smaWindow = WindowExpression(
UnresolvedFunction("AVG", Seq(ageField), isDistinct = false),
WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-1), CurrentRow)))
val trendlineProjectList = Seq(UnresolvedStar(None), Alias(smaWindow, "age_sma")())
val expectedPlan = Project(Seq(UnresolvedStar(None)), Project(trendlineProjectList, sort))
comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
}

test("test trendline sma command with fields command") {
val frame = sql(s"""
| source = $testTable | sort - age | trendline sma(3, age) as age_sma | fields name, age, age_sma
| """.stripMargin)

assert(frame.columns.sameElements(Array("name", "age", "age_sma")))
// Retrieve the results
val results: Array[Row] = frame.collect()
val expectedResults: Array[Row] =
Array(
Row("Jake", 70, 70.0),
Row("Hello", 30, 50.0),
Row("John", 25, 41.666666666666664),
Row("Jane", 20, 25))
// Compare the results
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0))
assert(results.sorted.sameElements(expectedResults.sorted))

// Retrieve the logical plan
val logicalPlan: LogicalPlan = frame.queryExecution.logical
val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))
val nameField = UnresolvedAttribute("name")
val ageField = UnresolvedAttribute("age")
val ageSmaField = UnresolvedAttribute("age_sma")
val sort = Sort(Seq(SortOrder(ageField, Descending)), global = true, table)
val smaWindow = WindowExpression(
UnresolvedFunction("AVG", Seq(ageField), isDistinct = false),
WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-2), CurrentRow)))
val trendlineProjectList = Seq(UnresolvedStar(None), Alias(smaWindow, "age_sma")())
val expectedPlan =
Project(Seq(nameField, ageField, ageSmaField), Project(trendlineProjectList, sort))
comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
}

test("test multiple trendline sma commands") {
val frame = sql(s"""
| source = $testTable | sort + age | trendline sma(2, age) as two_points_sma sma(3, age) as three_points_sma | fields name, age, two_points_sma, three_points_sma
| """.stripMargin)

assert(frame.columns.sameElements(Array("name", "age", "two_points_sma", "three_points_sma")))
// Retrieve the results
val results: Array[Row] = frame.collect()
val expectedResults: Array[Row] =
Array(
Row("Jane", 20, 20, 20),
Row("John", 25, 22.5, 22.5),
Row("Hello", 30, 27.5, 25.0),
Row("Jake", 70, 50.0, 41.666666666666664))
// Compare the results
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0))
assert(results.sorted.sameElements(expectedResults.sorted))

// Retrieve the logical plan
val logicalPlan: LogicalPlan = frame.queryExecution.logical
val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))
val nameField = UnresolvedAttribute("name")
val ageField = UnresolvedAttribute("age")
val ageTwoPointsSmaField = UnresolvedAttribute("two_points_sma")
val ageThreePointsSmaField = UnresolvedAttribute("three_points_sma")
val sort = Sort(Seq(SortOrder(ageField, Ascending)), global = true, table)
val twoPointsSmaWindow = WindowExpression(
UnresolvedFunction("AVG", Seq(ageField), isDistinct = false),
WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-1), CurrentRow)))
val threePointsSmaWindow = WindowExpression(
UnresolvedFunction("AVG", Seq(ageField), isDistinct = false),
WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-2), CurrentRow)))
val trendlineProjectList = Seq(
UnresolvedStar(None),
Alias(twoPointsSmaWindow, "two_points_sma")(),
Alias(threePointsSmaWindow, "three_points_sma")())
val expectedPlan = Project(
Seq(nameField, ageField, ageTwoPointsSmaField, ageThreePointsSmaField),
Project(trendlineProjectList, sort))
comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
}

test("test trendline sma command on evaluated column") {
val frame = sql(s"""
| source = $testTable | sort + age
| | eval doubled_age = age * 2 | trendline sma(2, doubled_age) as doubled_age_sma | fields name, doubled_age, doubled_age_sma
| """.stripMargin)

assert(frame.columns.sameElements(Array("name", "doubled_age", "doubled_age_sma")))
// Retrieve the results
val results: Array[Row] = frame.collect()
val expectedResults: Array[Row] =
Array(
Row("Jane", 40, 40.0),
Row("John", 50, 45.0),
Row("Hello", 60, 55.0),
Row("Jake", 140, 100.0))
// Compare the results
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0))
assert(results.sorted.sameElements(expectedResults.sorted))

-// Retrieve the logical plan
-val logicalPlan: LogicalPlan =
-frame.queryExecution.commandExecuted.asInstanceOf[CommandResult].commandLogicalPlan
-// Define the expected logical plan
-val expectedPlan: LogicalPlan = Project(Seq(), new UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")))
-// Compare the two plans
val logicalPlan: LogicalPlan = frame.queryExecution.logical
val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))
val nameField = UnresolvedAttribute("name")
val ageField = UnresolvedAttribute("age")
val doubledAgeField = UnresolvedAttribute("doubled_age")
val doubledAgeSmaField = UnresolvedAttribute("doubled_age_sma")
val sort = Sort(Seq(SortOrder(ageField, Ascending)), global = true, table)
val evalProject = Project(Seq(UnresolvedStar(None), Alias(UnresolvedFunction("*", Seq(ageField, Literal(2)), isDistinct = false), "doubled_age")()), sort)
val doubleAgeSmaWindow = WindowExpression(
UnresolvedFunction("AVG", Seq(doubledAgeField), isDistinct = false),
WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-1), CurrentRow)))
val trendlineProjectList = Seq(
UnresolvedStar(None),
Alias(doubleAgeSmaWindow, "doubled_age_sma")())
val expectedPlan = Project(
Seq(nameField, doubledAgeField, doubledAgeSmaField),
Project(trendlineProjectList, evalProject))
comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
}
}
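For the evaluated-column case, the same sketch extends naturally (again illustrative only, reusing the `people` DataFrame and the Window/avg/col imports from the sketch near the top of this page): derive the column first, then window over it.

```scala
import org.apache.spark.sql.functions.asc

// Equivalent of `sort + age | eval doubled_age = age * 2
//                | trendline sma(2, doubled_age) as doubled_age_sma`
val ascWindow = Window.orderBy(asc("age")).rowsBetween(-1, Window.currentRow)
people
  .withColumn("doubled_age", col("age") * 2)
  .withColumn("doubled_age_sma", avg(col("doubled_age")).over(ascWindow))
  .select("name", "doubled_age", "doubled_age_sma")
  .show()
// doubled_age_sma: 40.0, 45.0, 55.0, 100.0 — the values asserted in the last test above.
```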

@@ -271,9 +271,16 @@ public LogicalPlan visitLookup(Lookup node, CatalystPlanContext context) {

@Override
public LogicalPlan visitTrendline(Trendline node, CatalystPlanContext context) {
-LogicalPlan child = node.getChild().get(0).accept(this, context);
node.getChild().get(0).accept(this, context);

if (context.getNamedParseExpressions().isEmpty()) {
// Create an UnresolvedStar for all-fields projection
context.getNamedParseExpressions().push(UnresolvedStar$.MODULE$.apply(Option.empty()));
}
visitExpressionList(node.getComputations(), context);
-return child;
Seq<NamedExpression> projectExpressions = context.retainAllNamedParseExpressions(p -> (NamedExpression) p);

return context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Project(projectExpressions, p));
}
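For reference, a hedged Scala sketch (not part of the change) of the plan shape visitTrendline now emits for `trendline sma(2, age) as age_sma`: a Project that keeps every input column via UnresolvedStar and appends the aliased window expression, mirroring the expected plans in the test suite above. The helper name is hypothetical.

```scala
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedStar}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}

// trendlineProject is a hypothetical name, used here only to show the shape of the output plan.
def trendlineProject(child: LogicalPlan): LogicalPlan = {
  val sma = WindowExpression(
    UnresolvedFunction("AVG", Seq(UnresolvedAttribute("age")), isDistinct = false),
    WindowSpecDefinition(Nil, Nil, SpecifiedWindowFrame(RowFrame, Literal(-1), CurrentRow)))
  // Keep all incoming columns and add the SMA column under its user-supplied alias.
  Project(Seq(UnresolvedStar(None), Alias(sma, "age_sma")()), child)
}
```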

@Override

@@ -39,6 +39,7 @@
import org.opensearch.sql.ast.expression.When;
import org.opensearch.sql.ast.expression.Xor;
import org.opensearch.sql.ast.tree.Trendline;
import org.opensearch.sql.common.antlr.SyntaxCheckException;
import org.opensearch.sql.common.utils.StringUtils;
import org.opensearch.sql.ppl.utils.ArgumentFactory;

@@ -116,7 +117,10 @@ public UnresolvedExpression visitEvalClause(OpenSearchPPLParser.EvalClauseContext

@Override
public UnresolvedExpression visitTrendlineClause(OpenSearchPPLParser.TrendlineClauseContext ctx) {
-Integer numberOfDataPoints = Integer.parseInt(ctx.numberOfDataPoints.getText());
int numberOfDataPoints = Integer.parseInt(ctx.numberOfDataPoints.getText());
if (numberOfDataPoints < 0) {
throw new SyntaxCheckException("Number of trendline data-points must be greater than or equal to 0");
}
Field dataField = (Field) this.visitFieldExpression(ctx.field);
String alias = ctx.alias.getText();
String computationType = ctx.trendlineType().getText();
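The new lower bound could also be exercised end-to-end; a hedged sketch of such a test, assuming a suite like FlintSparkPPLTrendlineITSuite above (the test name and the assertThrows style are illustrative, not part of this commit):

```scala
test("trendline sma rejects a negative number of data points (sketch)") {
  // Parsing `sma(-1, age)` should now fail (SyntaxCheckException or a grammar error)
  // before any window frame is built.
  assertThrows[Exception] {
    sql(s"source = $testTable | trendline sma(-1, age) as age_sma").collect()
  }
}
```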