Skip to content

Commit

Permalink
return null when there are too few data points
Browse files Browse the repository at this point in the history
Signed-off-by: Kacper Trochimiak <[email protected]>
  • Loading branch information
kt-eliatra authored and lukasz-soszynski-eliatra committed Oct 29, 2024
1 parent 07aa10c commit ffe0581
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ package org.opensearch.flint.spark.ppl

import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, CurrentRow, Descending, Literal, RowFrame, SortOrder, SpecifiedWindowFrame, WindowExpression, WindowSpecDefinition}
import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, CaseWhen, CurrentRow, Descending, LessThan, Literal, RowFrame, SortOrder, SpecifiedWindowFrame, WindowExpression, WindowSpecDefinition}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.streaming.StreamTest

Expand Down Expand Up @@ -48,7 +48,7 @@ class FlintSparkPPLTrendlineITSuite
val results: Array[Row] = frame.collect()
val expectedResults: Array[Row] =
Array(
Row("Jake", 70, "California", "USA", 2023, 4, 70.0),
Row("Jake", 70, "California", "USA", 2023, 4, null),
Row("Hello", 30, "New York", "USA", 2023, 4, 50.0),
Row("John", 25, "Ontario", "Canada", 2023, 4, 27.5),
Row("Jane", 20, "Quebec", "Canada", 2023, 4, 22.5))
Expand All @@ -61,10 +61,15 @@ class FlintSparkPPLTrendlineITSuite
val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))
val ageField = UnresolvedAttribute("age")
val sort = Sort(Seq(SortOrder(ageField, Descending)), global = true, table)
val countWindow = new WindowExpression(
UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false),
WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-1), CurrentRow))
)
val smaWindow = WindowExpression(
UnresolvedFunction("AVG", Seq(ageField), isDistinct = false),
WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-1), CurrentRow)))
val trendlineProjectList = Seq(UnresolvedStar(None), Alias(smaWindow, "age_sma")())
val caseWhen = CaseWhen(Seq((LessThan(countWindow, Literal(2)), Literal(null))), smaWindow)
val trendlineProjectList = Seq(UnresolvedStar(None), Alias(caseWhen, "age_sma")())
val expectedPlan = Project(Seq(UnresolvedStar(None)), Project(trendlineProjectList, sort))
comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
}
Expand All @@ -79,8 +84,8 @@ class FlintSparkPPLTrendlineITSuite
val results: Array[Row] = frame.collect()
val expectedResults: Array[Row] =
Array(
Row("Jake", 70, 70.0),
Row("Hello", 30, 50.0),
Row("Jake", 70, null),
Row("Hello", 30, null),
Row("John", 25, 41.666666666666664),
Row("Jane", 20, 25))
// Compare the results
Expand All @@ -94,10 +99,14 @@ class FlintSparkPPLTrendlineITSuite
val ageField = UnresolvedAttribute("age")
val ageSmaField = UnresolvedAttribute("age_sma")
val sort = Sort(Seq(SortOrder(ageField, Descending)), global = true, table)
val countWindow = new WindowExpression(
UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false),
WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-2), CurrentRow)))
val smaWindow = WindowExpression(
UnresolvedFunction("AVG", Seq(ageField), isDistinct = false),
WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-2), CurrentRow)))
val trendlineProjectList = Seq(UnresolvedStar(None), Alias(smaWindow, "age_sma")())
val caseWhen = CaseWhen(Seq((LessThan(countWindow, Literal(3)), Literal(null))), smaWindow)
val trendlineProjectList = Seq(UnresolvedStar(None), Alias(caseWhen, "age_sma")())
val expectedPlan =
Project(Seq(nameField, ageField, ageSmaField), Project(trendlineProjectList, sort))
comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
Expand All @@ -113,8 +122,8 @@ class FlintSparkPPLTrendlineITSuite
val results: Array[Row] = frame.collect()
val expectedResults: Array[Row] =
Array(
Row("Jane", 20, 20, 20),
Row("John", 25, 22.5, 22.5),
Row("Jane", 20, null, null),
Row("John", 25, 22.5, null),
Row("Hello", 30, 27.5, 25.0),
Row("Jake", 70, 50.0, 41.666666666666664))
// Compare the results
Expand All @@ -129,16 +138,26 @@ class FlintSparkPPLTrendlineITSuite
val ageTwoPointsSmaField = UnresolvedAttribute("two_points_sma")
val ageThreePointsSmaField = UnresolvedAttribute("three_points_sma")
val sort = Sort(Seq(SortOrder(ageField, Ascending)), global = true, table)
val twoPointsCountWindow = new WindowExpression(
UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false),
WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-1), CurrentRow))
)
val twoPointsSmaWindow = WindowExpression(
UnresolvedFunction("AVG", Seq(ageField), isDistinct = false),
WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-1), CurrentRow)))
val threePointsCountWindow = new WindowExpression(
UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false),
WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-2), CurrentRow))
)
val threePointsSmaWindow = WindowExpression(
UnresolvedFunction("AVG", Seq(ageField), isDistinct = false),
WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-2), CurrentRow)))
val twoPointsCaseWhen = CaseWhen(Seq((LessThan(twoPointsCountWindow, Literal(2)), Literal(null))), twoPointsSmaWindow)
val threePointsCaseWhen = CaseWhen(Seq((LessThan(threePointsCountWindow, Literal(3)), Literal(null))), threePointsSmaWindow)
val trendlineProjectList = Seq(
UnresolvedStar(None),
Alias(twoPointsSmaWindow, "two_points_sma")(),
Alias(threePointsSmaWindow, "three_points_sma")())
Alias(twoPointsCaseWhen, "two_points_sma")(),
Alias(threePointsCaseWhen, "three_points_sma")())
val expectedPlan = Project(
Seq(nameField, ageField, ageTwoPointsSmaField, ageThreePointsSmaField),
Project(trendlineProjectList, sort))
Expand All @@ -155,7 +174,7 @@ class FlintSparkPPLTrendlineITSuite
val results: Array[Row] = frame.collect()
val expectedResults: Array[Row] =
Array(
Row("Jane", 40, 40.0),
Row("Jane", 40, null),
Row("John", 50, 45.0),
Row("Hello", 60, 55.0),
Row("Jake", 140, 100.0))
Expand All @@ -178,11 +197,16 @@ class FlintSparkPPLTrendlineITSuite
"doubled_age")()),
table)
val sort = Sort(Seq(SortOrder(ageField, Ascending)), global = true, evalProject)
val countWindow = new WindowExpression(
UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false),
WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-1), CurrentRow))
)
val doubleAgeSmaWindow = WindowExpression(
UnresolvedFunction("AVG", Seq(doubledAgeField), isDistinct = false),
WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-1), CurrentRow)))
val caseWhen = CaseWhen(Seq((LessThan(countWindow, Literal(2)), Literal(null))), doubleAgeSmaWindow)
val trendlineProjectList =
Seq(UnresolvedStar(None), Alias(doubleAgeSmaWindow, "doubled_age_sma")())
Seq(UnresolvedStar(None), Alias(caseWhen, "doubled_age_sma")())
val expectedPlan = Project(
Seq(nameField, doubledAgeField, doubledAgeSmaField),
Project(trendlineProjectList, sort))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,13 @@
import org.apache.spark.sql.catalyst.expressions.In$;
import org.apache.spark.sql.catalyst.expressions.GreaterThanOrEqual;
import org.apache.spark.sql.catalyst.expressions.InSubquery$;
import org.apache.spark.sql.catalyst.expressions.LessThan;
import org.apache.spark.sql.catalyst.expressions.LessThanOrEqual;
import org.apache.spark.sql.catalyst.expressions.ListQuery$;
import org.apache.spark.sql.catalyst.expressions.NamedExpression;
import org.apache.spark.sql.catalyst.expressions.Predicate;
import org.apache.spark.sql.catalyst.expressions.ScalarSubquery$;
import org.apache.spark.sql.catalyst.expressions.RowFrame$;
import org.apache.spark.sql.catalyst.expressions.ScalarSubquery$;
import org.apache.spark.sql.catalyst.expressions.SortDirection;
import org.apache.spark.sql.catalyst.expressions.SortOrder;
import org.apache.spark.sql.catalyst.expressions.SpecifiedWindowFrame;
Expand Down Expand Up @@ -681,19 +682,30 @@ public Expression visitSpan(Span node, CatalystPlanContext context) {

@Override
public Expression visitTrendlineComputation(Trendline.TrendlineComputation node, CatalystPlanContext context) {
this.visitAggregateFunction(new AggregateFunction(BuiltinFunctionName.AVG.name(), node.getDataField()), context);
Expression avgFunction = context.popNamedParseExpressions().get();
//window lower boundary
this.visitLiteral(new Literal(Math.negateExact(node.getNumberOfDataPoints() - 1), DataType.INTEGER), context);
Expression windowLowerBoundary = context.popNamedParseExpressions().get();

//window definition
WindowSpecDefinition windowDefinition = new WindowSpecDefinition(
seq(),
seq(),
new SpecifiedWindowFrame(RowFrame$.MODULE$, windowLowerBoundary, CurrentRow$.MODULE$));

if (node.getComputationType() == Trendline.TrendlineType.SMA) {
//calculate avg value of the data field
this.visitAggregateFunction(new AggregateFunction(BuiltinFunctionName.AVG.name(), node.getDataField()), context);
Expression avgFunction = context.popNamedParseExpressions().get();

//sma window
WindowExpression sma = new WindowExpression(
avgFunction,
new WindowSpecDefinition(
seq(),
seq(),
new SpecifiedWindowFrame(RowFrame$.MODULE$, windowLowerBoundary, CurrentRow$.MODULE$)));
windowDefinition);

CaseWhen smaOrNull = trendlineOrNullWhenThereAreTooFewDataPoints(sma, node, context);

return context.getNamedParseExpressions().push(
org.apache.spark.sql.catalyst.expressions.Alias$.MODULE$.apply(sma,
org.apache.spark.sql.catalyst.expressions.Alias$.MODULE$.apply(smaOrNull,
node.getAlias(),
NamedExpression.newExprId(),
seq(new java.util.ArrayList<String>()),
Expand All @@ -704,6 +716,28 @@ public Expression visitTrendlineComputation(Trendline.TrendlineComputation node,
}
}

/**
 * Wraps a trendline window expression in a {@code CASE WHEN} that yields NULL whenever the
 * window frame holds fewer rows than the trendline's required number of data points, and the
 * trendline value otherwise. This prevents partially-filled leading windows from producing
 * misleading averages (e.g. a 3-point SMA over only 1 or 2 rows).
 *
 * @param trendlineWindow the trendline window expression (e.g. AVG over a row frame); its
 *                        window spec is reused so the row count is taken over the same frame
 * @param node            the trendline computation node providing the required data-point count
 * @param context         the plan context whose expression stack the visit* calls push onto
 * @return {@code CASE WHEN COUNT(1) OVER frame < requiredPoints THEN NULL ELSE trendline END}
 */
private CaseWhen trendlineOrNullWhenThereAreTooFewDataPoints(WindowExpression trendlineWindow, Trendline.TrendlineComputation node, CatalystPlanContext context) {
    // Required number of data points: each visit* call pushes onto the context's named
    // parse-expression stack, so pop immediately to retrieve the translated expression.
    this.visitLiteral(new Literal(node.getNumberOfDataPoints(), DataType.INTEGER), context);
    Expression requiredNumberOfDataPoints = context.popNamedParseExpressions().get();

    // COUNT(1) evaluated over the SAME window spec as the trendline, so it reports how many
    // rows the current frame actually contains.
    this.visitAggregateFunction(new AggregateFunction(BuiltinFunctionName.COUNT.name(), new Literal(1, DataType.INTEGER)), context);
    Expression countDataPointsFunction = context.popNamedParseExpressions().get();
    WindowExpression countDataPointsWindow = new WindowExpression(
            countDataPointsFunction,
            trendlineWindow.windowSpec());

    // NULL literal returned while the frame is underpopulated.
    this.visitLiteral(new Literal(null, DataType.NULL), context);
    Expression nullLiteral = context.popNamedParseExpressions().get();

    // Single WHEN branch: count < required -> NULL; ELSE falls through to the trendline.
    // (Renamed from "...LessThenRequired" to fix the Then/Than typo.)
    Tuple2<Expression, Expression> nullWhenNumberOfDataPointsLessThanRequired = new Tuple2<>(
            new LessThan(countDataPointsWindow, requiredNumberOfDataPoints),
            nullLiteral
    );
    return new CaseWhen(seq(nullWhenNumberOfDataPointsLessThanRequired), Option.apply(trendlineWindow));
}

@Override
public Expression visitAggregateFunction(AggregateFunction node, CatalystPlanContext context) {
node.getField().accept(this, context);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,14 @@
import org.apache.spark.sql.types.FloatType$;
import org.apache.spark.sql.types.IntegerType$;
import org.apache.spark.sql.types.LongType$;
import org.apache.spark.sql.types.NullType$;
import org.apache.spark.sql.types.ShortType$;
import org.apache.spark.sql.types.StringType$;
import org.apache.spark.unsafe.types.UTF8String;
import org.opensearch.sql.ast.expression.SpanUnit;
import scala.collection.mutable.Seq;

import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;

import static org.opensearch.sql.ast.expression.SpanUnit.DAY;
import static org.opensearch.sql.ast.expression.SpanUnit.HOUR;
Expand Down Expand Up @@ -67,6 +65,8 @@ static DataType translate(org.opensearch.sql.ast.expression.DataType source) {
return ShortType$.MODULE$;
case BYTE:
return ByteType$.MODULE$;
case UNDEFINED:
return NullType$.MODULE$;
default:
return StringType$.MODULE$;
}
Expand Down

0 comments on commit ffe0581

Please sign in to comment.