diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 index b43435b24..66bc6fe1d 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 @@ -263,7 +263,8 @@ trendlineClause ; trendlineType - : (SMA | WMA) + : SMA + | WMA ; kmeansCommand diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/TrendlineCatalystUtils.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/TrendlineCatalystUtils.java index 307041bd4..513561bfa 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/TrendlineCatalystUtils.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/TrendlineCatalystUtils.java @@ -96,10 +96,19 @@ private static CaseWhen trendlineOrNullWhenThereAreTooFewDataPoints(CatalystExpr * Responsible to produce a Spark Logical Plan with given TrendLine command arguments, below is the sample logical plan * with configuration [dataField=salary, sortField=age, dataPoints=3] * -- +- 'Project [ - * -- (((('nth_value('salary, 1) windowspecdefinition(Field(field=age, fieldArgs=[]) ASC NULLS FIRST, specifiedwindowframe(RowFrame, -1, currentrow$())) * 1) + - * -- ('nth_value('salary, 2) windowspecdefinition(Field(field=age, fieldArgs=[]) ASC NULLS FIRST, specifiedwindowframe(RowFrame, -1, currentrow$())) * 2)) + - * -- ('nth_value('salary, 3) windowspecdefinition(Field(field=age, fieldArgs=[]) ASC NULLS FIRST, specifiedwindowframe(RowFrame, -1, currentrow$())) * 3)) / 6) + * -- (((('nth_value('salary, 1) windowspecdefinition(Field(field=age, fieldArgs=[]) ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 1) + + * -- ('nth_value('salary, 2) windowspecdefinition(Field(field=age, fieldArgs=[]) ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 2)) + + * -- ('nth_value('salary, 3) windowspecdefinition(Field(field=age, fieldArgs=[]) ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 3)) / 6) * -- AS WMA#702] + * . + * And the corresponded SQL query: + * . + * SELECT name, salary, + * ( nth_value(salary, 1) OVER (ORDER BY age ROWS BETWEEN 2 PRECEDING AND CURRENT ROW) *1 + + * nth_value(salary, 2) OVER (ORDER BY age ROWS BETWEEN 2 PRECEDING AND CURRENT ROW) *2 + + * nth_value(salary, 3) OVER (ORDER BY age ROWS BETWEEN 2 PRECEDING AND CURRENT ROW) *3 )/6 AS WMA + * FROM employees + * ORDER BY age; * * @param visitor Visitor instance to process any UnresolvedExpression. * @param node Trendline command's arguments. @@ -113,7 +122,7 @@ private static NamedExpression getWMAComputationExpression(CatalystExpressionVis CatalystPlanContext context) { int dataPoints = node.getNumberOfDataPoints(); //window lower boundary - Expression windowLowerBoundary = getIntExpression(visitor, context, + Expression windowLowerBoundary = parseIntToExpression(visitor, context, Math.negateExact(dataPoints - 1)); //window definition visitor.analyze(sortField, context); @@ -123,15 +132,15 @@ private static NamedExpression getWMAComputationExpression(CatalystExpressionVis SortUtils.isSortedAscending(sortField), windowLowerBoundary); // Divisor - Expression divisor = getIntExpression(visitor, context, + Expression divisor = parseIntToExpression(visitor, context, (dataPoints * (dataPoints + 1) / 2)); // Aggregation - Expression WMAExpression = getNthValueAggregations(visitor, node, context, windowDefinition, dataPoints) + Expression wmaExpression = getNthValueAggregations(visitor, node, context, windowDefinition, dataPoints) .stream() .reduce(Add::new) .orElse(null); - return getAlias(node.getAlias(), new Divide(WMAExpression, divisor)); + return getAlias(node.getAlias(), new Divide(wmaExpression, divisor)); } /** @@ -156,7 +165,7 @@ private static NamedExpression getAlias(String name, Expression expression) { * @param i Target value for the expression. * @return An expression object which contain integer value i. */ - static Expression getIntExpression(CatalystExpressionVisitor expressionVisitor, CatalystPlanContext context, int i) { + static Expression parseIntToExpression(CatalystExpressionVisitor expressionVisitor, CatalystPlanContext context, int i) { expressionVisitor.visitLiteral(new Literal(i, DataType.INTEGER), context); return context.popNamedParseExpressions().get(); @@ -198,7 +207,7 @@ private static List getNthValueAggregations(CatalystExpressionVisito List expressions = new ArrayList<>(); for (int i = 1; i <= dataPoints; i++) { // Get the offset parameter - Expression offSetExpression = getIntExpression(visitor, context, i); + Expression offSetExpression = parseIntToExpression(visitor, context, i); // Composite the nth_value expression. Function func = new Function(BuiltinFunctionName.NTH_VALUE.name(),