Skip to content

Commit

Permalink
Addres comments
Browse files Browse the repository at this point in the history
Signed-off-by: Andy Kwok <[email protected]>
  • Loading branch information
andy-k-improving committed Nov 7, 2024
1 parent a3b00a7 commit 02c517e
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 10 deletions.
3 changes: 2 additions & 1 deletion ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,8 @@ trendlineClause
;

trendlineType
: (SMA | WMA)
: SMA
| WMA
;

kmeansCommand
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,10 +96,19 @@ private static CaseWhen trendlineOrNullWhenThereAreTooFewDataPoints(CatalystExpr
* Responsible to produce a Spark Logical Plan with given TrendLine command arguments, below is the sample logical plan
* with configuration [dataField=salary, sortField=age, dataPoints=3]
* -- +- 'Project [
* -- (((('nth_value('salary, 1) windowspecdefinition(Field(field=age, fieldArgs=[]) ASC NULLS FIRST, specifiedwindowframe(RowFrame, -1, currentrow$())) * 1) +
* -- ('nth_value('salary, 2) windowspecdefinition(Field(field=age, fieldArgs=[]) ASC NULLS FIRST, specifiedwindowframe(RowFrame, -1, currentrow$())) * 2)) +
* -- ('nth_value('salary, 3) windowspecdefinition(Field(field=age, fieldArgs=[]) ASC NULLS FIRST, specifiedwindowframe(RowFrame, -1, currentrow$())) * 3)) / 6)
* -- (((('nth_value('salary, 1) windowspecdefinition(Field(field=age, fieldArgs=[]) ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 1) +
* -- ('nth_value('salary, 2) windowspecdefinition(Field(field=age, fieldArgs=[]) ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 2)) +
* -- ('nth_value('salary, 3) windowspecdefinition(Field(field=age, fieldArgs=[]) ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 3)) / 6)
* -- AS WMA#702]
* .
* And the corresponded SQL query:
* .
* SELECT name, salary,
* ( nth_value(salary, 1) OVER (ORDER BY age ROWS BETWEEN 2 PRECEDING AND CURRENT ROW) *1 +
* nth_value(salary, 2) OVER (ORDER BY age ROWS BETWEEN 2 PRECEDING AND CURRENT ROW) *2 +
* nth_value(salary, 3) OVER (ORDER BY age ROWS BETWEEN 2 PRECEDING AND CURRENT ROW) *3 )/6 AS WMA
* FROM employees
* ORDER BY age;
*
* @param visitor Visitor instance to process any UnresolvedExpression.
* @param node Trendline command's arguments.
Expand All @@ -113,7 +122,7 @@ private static NamedExpression getWMAComputationExpression(CatalystExpressionVis
CatalystPlanContext context) {
int dataPoints = node.getNumberOfDataPoints();
//window lower boundary
Expression windowLowerBoundary = getIntExpression(visitor, context,
Expression windowLowerBoundary = parseIntToExpression(visitor, context,
Math.negateExact(dataPoints - 1));
//window definition
visitor.analyze(sortField, context);
Expand All @@ -123,15 +132,15 @@ private static NamedExpression getWMAComputationExpression(CatalystExpressionVis
SortUtils.isSortedAscending(sortField),
windowLowerBoundary);
// Divisor
Expression divisor = getIntExpression(visitor, context,
Expression divisor = parseIntToExpression(visitor, context,
(dataPoints * (dataPoints + 1) / 2));
// Aggregation
Expression WMAExpression = getNthValueAggregations(visitor, node, context, windowDefinition, dataPoints)
Expression wmaExpression = getNthValueAggregations(visitor, node, context, windowDefinition, dataPoints)
.stream()
.reduce(Add::new)
.orElse(null);

return getAlias(node.getAlias(), new Divide(WMAExpression, divisor));
return getAlias(node.getAlias(), new Divide(wmaExpression, divisor));
}

/**
Expand All @@ -156,7 +165,7 @@ private static NamedExpression getAlias(String name, Expression expression) {
* @param i Target value for the expression.
* @return An expression object which contain integer value i.
*/
static Expression getIntExpression(CatalystExpressionVisitor expressionVisitor, CatalystPlanContext context, int i) {
static Expression parseIntToExpression(CatalystExpressionVisitor expressionVisitor, CatalystPlanContext context, int i) {
expressionVisitor.visitLiteral(new Literal(i,
DataType.INTEGER), context);
return context.popNamedParseExpressions().get();
Expand Down Expand Up @@ -198,7 +207,7 @@ private static List<Expression> getNthValueAggregations(CatalystExpressionVisito
List<Expression> expressions = new ArrayList<>();
for (int i = 1; i <= dataPoints; i++) {
// Get the offset parameter
Expression offSetExpression = getIntExpression(visitor, context, i);
Expression offSetExpression = parseIntToExpression(visitor, context, i);

// Composite the nth_value expression.
Function func = new Function(BuiltinFunctionName.NTH_VALUE.name(),
Expand Down

0 comments on commit 02c517e

Please sign in to comment.