From ffe0581b37e56e955be2e9cc703f561c790be055 Mon Sep 17 00:00:00 2001 From: Kacper Trochimiak Date: Tue, 29 Oct 2024 16:27:26 +0100 Subject: [PATCH] return null when there are too few data points Signed-off-by: Kacper Trochimiak --- .../ppl/FlintSparkPPLTrendlineITSuite.scala | 48 +++++++++++++----- .../sql/ppl/CatalystQueryPlanVisitor.java | 50 ++++++++++++++++--- .../sql/ppl/utils/DataTypeTransformer.java | 6 +-- 3 files changed, 81 insertions(+), 23 deletions(-) diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTrendlineITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTrendlineITSuite.scala index 59a70aed6..111daf892 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTrendlineITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTrendlineITSuite.scala @@ -7,7 +7,7 @@ package org.opensearch.flint.spark.ppl import org.apache.spark.sql.{QueryTest, Row} import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} -import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, CurrentRow, Descending, Literal, RowFrame, SortOrder, SpecifiedWindowFrame, WindowExpression, WindowSpecDefinition} +import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, CaseWhen, CurrentRow, Descending, LessThan, Literal, RowFrame, SortOrder, SpecifiedWindowFrame, WindowExpression, WindowSpecDefinition} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.streaming.StreamTest @@ -48,7 +48,7 @@ class FlintSparkPPLTrendlineITSuite val results: Array[Row] = frame.collect() val expectedResults: Array[Row] = Array( - Row("Jake", 70, "California", "USA", 2023, 4, 70.0), + Row("Jake", 70, "California", "USA", 2023, 4, null), Row("Hello", 30, "New York", "USA", 2023, 4, 50.0), Row("John", 25, "Ontario", "Canada", 2023, 4, 27.5), Row("Jane", 20, "Quebec", "Canada", 2023, 4, 22.5)) @@ -61,10 +61,15 @@ class FlintSparkPPLTrendlineITSuite val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) val ageField = UnresolvedAttribute("age") val sort = Sort(Seq(SortOrder(ageField, Descending)), global = true, table) + val countWindow = new WindowExpression( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-1), CurrentRow)) + ) val smaWindow = WindowExpression( UnresolvedFunction("AVG", Seq(ageField), isDistinct = false), WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-1), CurrentRow))) - val trendlineProjectList = Seq(UnresolvedStar(None), Alias(smaWindow, "age_sma")()) + val caseWhen = CaseWhen(Seq((LessThan(countWindow, Literal(2)), Literal(null))), smaWindow) + val trendlineProjectList = Seq(UnresolvedStar(None), Alias(caseWhen, "age_sma")()) val expectedPlan = Project(Seq(UnresolvedStar(None)), Project(trendlineProjectList, sort)) comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) } @@ -79,8 +84,8 @@ class FlintSparkPPLTrendlineITSuite val results: Array[Row] = frame.collect() val expectedResults: Array[Row] = Array( - Row("Jake", 70, 70.0), - Row("Hello", 30, 50.0), + Row("Jake", 70, null), + Row("Hello", 30, null), Row("John", 25, 41.666666666666664), Row("Jane", 20, 25)) // Compare the results @@ -94,10 +99,14 @@ class FlintSparkPPLTrendlineITSuite val ageField = UnresolvedAttribute("age") val ageSmaField = UnresolvedAttribute("age_sma") val sort = Sort(Seq(SortOrder(ageField, Descending)), global = true, table) + val countWindow = new WindowExpression( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-2), CurrentRow))) val smaWindow = WindowExpression( UnresolvedFunction("AVG", Seq(ageField), isDistinct = false), WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-2), CurrentRow))) - val trendlineProjectList = Seq(UnresolvedStar(None), Alias(smaWindow, "age_sma")()) + val caseWhen = CaseWhen(Seq((LessThan(countWindow, Literal(3)), Literal(null))), smaWindow) + val trendlineProjectList = Seq(UnresolvedStar(None), Alias(caseWhen, "age_sma")()) val expectedPlan = Project(Seq(nameField, ageField, ageSmaField), Project(trendlineProjectList, sort)) comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) @@ -113,8 +122,8 @@ class FlintSparkPPLTrendlineITSuite val results: Array[Row] = frame.collect() val expectedResults: Array[Row] = Array( - Row("Jane", 20, 20, 20), - Row("John", 25, 22.5, 22.5), + Row("Jane", 20, null, null), + Row("John", 25, 22.5, null), Row("Hello", 30, 27.5, 25.0), Row("Jake", 70, 50.0, 41.666666666666664)) // Compare the results @@ -129,16 +138,26 @@ class FlintSparkPPLTrendlineITSuite val ageTwoPointsSmaField = UnresolvedAttribute("two_points_sma") val ageThreePointsSmaField = UnresolvedAttribute("three_points_sma") val sort = Sort(Seq(SortOrder(ageField, Ascending)), global = true, table) + val twoPointsCountWindow = new WindowExpression( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-1), CurrentRow)) + ) val twoPointsSmaWindow = WindowExpression( UnresolvedFunction("AVG", Seq(ageField), isDistinct = false), WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-1), CurrentRow))) + val threePointsCountWindow = new WindowExpression( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-2), CurrentRow)) + ) val threePointsSmaWindow = WindowExpression( UnresolvedFunction("AVG", Seq(ageField), isDistinct = false), WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-2), CurrentRow))) + val twoPointsCaseWhen = CaseWhen(Seq((LessThan(twoPointsCountWindow, Literal(2)), Literal(null))), twoPointsSmaWindow) + val threePointsCaseWhen = CaseWhen(Seq((LessThan(threePointsCountWindow, Literal(3)), Literal(null))), threePointsSmaWindow) val trendlineProjectList = Seq( UnresolvedStar(None), - Alias(twoPointsSmaWindow, "two_points_sma")(), - Alias(threePointsSmaWindow, "three_points_sma")()) + Alias(twoPointsCaseWhen, "two_points_sma")(), + Alias(threePointsCaseWhen, "three_points_sma")()) val expectedPlan = Project( Seq(nameField, ageField, ageTwoPointsSmaField, ageThreePointsSmaField), Project(trendlineProjectList, sort)) @@ -155,7 +174,7 @@ class FlintSparkPPLTrendlineITSuite val results: Array[Row] = frame.collect() val expectedResults: Array[Row] = Array( - Row("Jane", 40, 40.0), + Row("Jane", 40, null), Row("John", 50, 45.0), Row("Hello", 60, 55.0), Row("Jake", 140, 100.0)) @@ -178,11 +197,16 @@ class FlintSparkPPLTrendlineITSuite "doubled_age")()), table) val sort = Sort(Seq(SortOrder(ageField, Ascending)), global = true, evalProject) + val countWindow = new WindowExpression( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-1), CurrentRow)) + ) val doubleAgeSmaWindow = WindowExpression( UnresolvedFunction("AVG", Seq(doubledAgeField), isDistinct = false), WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-1), CurrentRow))) + val caseWhen = CaseWhen(Seq((LessThan(countWindow, Literal(2)), Literal(null))), doubleAgeSmaWindow) val trendlineProjectList = - Seq(UnresolvedStar(None), Alias(doubleAgeSmaWindow, "doubled_age_sma")()) + Seq(UnresolvedStar(None), Alias(caseWhen, "doubled_age_sma")()) val expectedPlan = Project( Seq(nameField, doubledAgeField, doubledAgeSmaField), Project(trendlineProjectList, sort)) diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java index 3c9716b5d..876ede1db 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java @@ -19,12 +19,13 @@ import org.apache.spark.sql.catalyst.expressions.In$; import org.apache.spark.sql.catalyst.expressions.GreaterThanOrEqual; import org.apache.spark.sql.catalyst.expressions.InSubquery$; +import org.apache.spark.sql.catalyst.expressions.LessThan; import org.apache.spark.sql.catalyst.expressions.LessThanOrEqual; import org.apache.spark.sql.catalyst.expressions.ListQuery$; import org.apache.spark.sql.catalyst.expressions.NamedExpression; import org.apache.spark.sql.catalyst.expressions.Predicate; -import org.apache.spark.sql.catalyst.expressions.ScalarSubquery$; import org.apache.spark.sql.catalyst.expressions.RowFrame$; +import org.apache.spark.sql.catalyst.expressions.ScalarSubquery$; import org.apache.spark.sql.catalyst.expressions.SortDirection; import org.apache.spark.sql.catalyst.expressions.SortOrder; import org.apache.spark.sql.catalyst.expressions.SpecifiedWindowFrame; @@ -681,19 +682,30 @@ public Expression visitSpan(Span node, CatalystPlanContext context) { @Override public Expression visitTrendlineComputation(Trendline.TrendlineComputation node, CatalystPlanContext context) { - this.visitAggregateFunction(new AggregateFunction(BuiltinFunctionName.AVG.name(), node.getDataField()), context); - Expression avgFunction = context.popNamedParseExpressions().get(); + //window lower boundary this.visitLiteral(new Literal(Math.negateExact(node.getNumberOfDataPoints() - 1), DataType.INTEGER), context); Expression windowLowerBoundary = context.popNamedParseExpressions().get(); + + //window definition + WindowSpecDefinition windowDefinition = new WindowSpecDefinition( + seq(), + seq(), + new SpecifiedWindowFrame(RowFrame$.MODULE$, windowLowerBoundary, CurrentRow$.MODULE$)); + if (node.getComputationType() == Trendline.TrendlineType.SMA) { + //calculate avg value of the data field + this.visitAggregateFunction(new AggregateFunction(BuiltinFunctionName.AVG.name(), node.getDataField()), context); + Expression avgFunction = context.popNamedParseExpressions().get(); + + //sma window WindowExpression sma = new WindowExpression( avgFunction, - new WindowSpecDefinition( - seq(), - seq(), - new SpecifiedWindowFrame(RowFrame$.MODULE$, windowLowerBoundary, CurrentRow$.MODULE$))); + windowDefinition); + + CaseWhen smaOrNull = trendlineOrNullWhenThereAreTooFewDataPoints(sma, node, context); + return context.getNamedParseExpressions().push( - org.apache.spark.sql.catalyst.expressions.Alias$.MODULE$.apply(sma, + org.apache.spark.sql.catalyst.expressions.Alias$.MODULE$.apply(smaOrNull, node.getAlias(), NamedExpression.newExprId(), seq(new java.util.ArrayList()), @@ -704,6 +716,28 @@ public Expression visitTrendlineComputation(Trendline.TrendlineComputation node, } } + private CaseWhen trendlineOrNullWhenThereAreTooFewDataPoints(WindowExpression trendlineWindow, Trendline.TrendlineComputation node, CatalystPlanContext context) { + //required number of data points + this.visitLiteral(new Literal(node.getNumberOfDataPoints(), DataType.INTEGER), context); + Expression requiredNumberOfDataPoints = context.popNamedParseExpressions().get(); + + //count data points function + this.visitAggregateFunction(new AggregateFunction(BuiltinFunctionName.COUNT.name(), new Literal(1, DataType.INTEGER)), context); + Expression countDataPointsFunction = context.popNamedParseExpressions().get(); + //count data points window + WindowExpression countDataPointsWindow = new WindowExpression( + countDataPointsFunction, + trendlineWindow.windowSpec()); + + this.visitLiteral(new Literal(null, DataType.NULL), context); + Expression nullLiteral = context.popNamedParseExpressions().get(); + Tuple2 nullWhenNumberOfDataPointsLessThenRequired = new Tuple2<>( + new LessThan(countDataPointsWindow, requiredNumberOfDataPoints), + nullLiteral + ); + return new CaseWhen(seq(nullWhenNumberOfDataPointsLessThenRequired), Option.apply(trendlineWindow)); + } + @Override public Expression visitAggregateFunction(AggregateFunction node, CatalystPlanContext context) { node.getField().accept(this, context); diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/DataTypeTransformer.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/DataTypeTransformer.java index 62eef90ed..e4defad52 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/DataTypeTransformer.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/DataTypeTransformer.java @@ -14,16 +14,14 @@ import org.apache.spark.sql.types.FloatType$; import org.apache.spark.sql.types.IntegerType$; import org.apache.spark.sql.types.LongType$; +import org.apache.spark.sql.types.NullType$; import org.apache.spark.sql.types.ShortType$; import org.apache.spark.sql.types.StringType$; import org.apache.spark.unsafe.types.UTF8String; import org.opensearch.sql.ast.expression.SpanUnit; import scala.collection.mutable.Seq; -import java.util.Arrays; import java.util.List; -import java.util.Objects; -import java.util.stream.Collectors; import static org.opensearch.sql.ast.expression.SpanUnit.DAY; import static org.opensearch.sql.ast.expression.SpanUnit.HOUR; @@ -67,6 +65,8 @@ static DataType translate(org.opensearch.sql.ast.expression.DataType source) { return ShortType$.MODULE$; case BYTE: return ByteType$.MODULE$; + case UNDEFINED: + return NullType$.MODULE$; default: return StringType$.MODULE$; }