Skip to content

Commit

Permalink
Java doc
Browse files Browse the repository at this point in the history
  • Loading branch information
andy-k-improving committed Nov 5, 2024
1 parent 49383be commit 5a8592f
Showing 1 changed file with 134 additions and 94 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import scala.Tuple2;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
Expand All @@ -40,154 +41,193 @@ static NamedExpression visitTrendlineComputation(CatalystExpressionVisitor expre
Expression windowLowerBoundary = context.popNamedParseExpressions().get();

//window definition
// windowspecdefinition(specifiedwindowframe(RowFrame, -2, currentrow$()
WindowSpecDefinition windowDefinition = new WindowSpecDefinition(
seq(),
seq(),
new SpecifiedWindowFrame(RowFrame$.MODULE$, windowLowerBoundary, CurrentRow$.MODULE$));

if (node.getComputationType() == Trendline.TrendlineType.SMA) {
//calculate avg value of the data field
expressionVisitor.visitAggregateFunction(new AggregateFunction(BuiltinFunctionName.AVG.name(), node.getDataField()), context);
Expression avgFunction = context.popNamedParseExpressions().get();

//sma window
WindowExpression sma = new WindowExpression(
avgFunction,
windowDefinition);

CaseWhen smaOrNull = trendlineOrNullWhenThereAreTooFewDataPoints(expressionVisitor, sma, node, context);

return org.apache.spark.sql.catalyst.expressions.Alias$.MODULE$.apply(smaOrNull,
node.getAlias(),
NamedExpression.newExprId(),
seq(new java.util.ArrayList<String>()),
Option.empty(),
seq(new java.util.ArrayList<String>()));
} else if (node.getComputationType() == Trendline.TrendlineType.WMA) {
if (sortField.isPresent()) {
return getWMAComputationExpression(expressionVisitor, node, sortField.get(), context);
} else {
throw new IllegalArgumentException(node.getComputationType()+" require a sort field for computation");
}
} else {
throw new IllegalArgumentException(node.getComputationType()+" is not supported");
switch (node.getComputationType()) {
case SMA:
//calculate avg value of the data field
expressionVisitor.visitAggregateFunction(new AggregateFunction(BuiltinFunctionName.AVG.name(), node.getDataField()), context);
Expression avgFunction = context.popNamedParseExpressions().get();

//sma window
WindowExpression sma = new WindowExpression(
avgFunction,
windowDefinition);

CaseWhen smaOrNull = trendlineOrNullWhenThereAreTooFewDataPoints(expressionVisitor, sma, node, context);

return getAlias(node.getAlias(), smaOrNull);
case WMA:
if (sortField.isPresent()) {
return getWMAComputationExpression(expressionVisitor, node, sortField.get(), context);
} else {
throw new IllegalArgumentException(node.getComputationType()+" require a sort field for computation");
}
default:
throw new IllegalArgumentException(node.getComputationType()+" is not supported");
}
}

/**
 * Wraps a trendline window expression so that it evaluates to a null literal whenever
 * the window holds fewer rows than the number of data points requested by the command.
 *
 * @param expressionVisitor Visitor used to turn literals and the COUNT aggregate into expressions.
 * @param trendlineWindow The window expression that computes the trendline value.
 * @param node Trendline command arguments; supplies the required number of data points.
 * @param context Context from which each visited expression is popped.
 * @return A CaseWhen whose single branch yields null when COUNT(1) over the trendline's window spec
 *         is less than the required number of data points, and whose else value is the trendline window.
 */
private static CaseWhen trendlineOrNullWhenThereAreTooFewDataPoints(CatalystExpressionVisitor expressionVisitor, WindowExpression trendlineWindow, Trendline.TrendlineComputation node, CatalystPlanContext context) {
    // Minimum number of rows the window must contain.
    Literal threshold = new Literal(node.getNumberOfDataPoints(), DataType.INTEGER);
    expressionVisitor.visitLiteral(threshold, context);
    Expression minimumRows = context.popNamedParseExpressions().get();

    // COUNT(1), evaluated over the very same window spec as the trendline itself.
    AggregateFunction countOne = new AggregateFunction(BuiltinFunctionName.COUNT.name(), new Literal(1, DataType.INTEGER));
    expressionVisitor.visitAggregateFunction(countOne, context);
    Expression rowCounter = context.popNamedParseExpressions().get();
    WindowExpression rowsInWindow = new WindowExpression(rowCounter, trendlineWindow.windowSpec());

    // Value produced when the window is too small.
    expressionVisitor.visitLiteral(new Literal(null, DataType.NULL), context);
    Expression nullResult = context.popNamedParseExpressions().get();

    LessThan tooFewRows = new LessThan(rowsInWindow, minimumRows);
    Tuple2<Expression, Expression> nullBranch = new Tuple2<>(tooFewRows, nullResult);
    return new CaseWhen(seq(nullBranch), Option.apply(trendlineWindow));
}

/**
* Produce a Spark Logical Plan in the form NamedExpression with given WindowSpecDefinition.
*
*/
private static NamedExpression getWMAComputationExpression(CatalystExpressionVisitor analyzer,

/**
* Responsible to produce a Spark Logical Plan with given TrendLine command arguments, below is the sample logical plan
* with configuration [dataField=salary, sortField=age, dataPoints=3]
* -- +- 'Project [
* -- (((('nth_value('salary, 1) windowspecdefinition(Field(field=age, fieldArgs=[]) ASC NULLS FIRST, specifiedwindowframe(RowFrame, -1, currentrow$())) * 1) +
* -- ('nth_value('salary, 2) windowspecdefinition(Field(field=age, fieldArgs=[]) ASC NULLS FIRST, specifiedwindowframe(RowFrame, -1, currentrow$())) * 2)) +
* -- ('nth_value('salary, 3) windowspecdefinition(Field(field=age, fieldArgs=[]) ASC NULLS FIRST, specifiedwindowframe(RowFrame, -1, currentrow$())) * 3)) / 6)
* -- AS WMA#702]
*
* @param visitor Visitor instance to process any UnresolvedExpression.
* @param node Trendline command's arguments.
* @param sortField Field used for window aggregation.
* @param context Context instance to retrieved Expression in resolved form.
* @return a NamedExpression instance which will calculate WMA with provided argument.
*/
private static NamedExpression getWMAComputationExpression(CatalystExpressionVisitor visitor,
Trendline.TrendlineComputation node,
Field sortField,
CatalystPlanContext context) {

System.out.println("Func 4");
//window lower boundary
Expression windowLowerBoundary = getIntExpression(analyzer, context,
Expression windowLowerBoundary = getIntExpression(visitor, context,
Math.negateExact(node.getNumberOfDataPoints() - 1));
// The field name
Expression dataField = getStringExpression(analyzer, context, node.getDataField());
//window definition
WindowSpecDefinition windowDefinition = getCommonWindowDefinition(
analyzer.analyze(sortField, context),
visitor.analyze(sortField, context),
SortUtils.isSortedAscending(sortField),
windowLowerBoundary);
// Divider
Expression divider = getIntExpression(analyzer, context,
Expression divider = getIntExpression(visitor, context,
(node.getNumberOfDataPoints() * (node.getNumberOfDataPoints()+1) / 2));
// Aggregation
Expression WMAExpression = sum(
getNthValueAggregations(analyzer, node, context, windowDefinition, node.getNumberOfDataPoints()));
Expression WMAExpression = getNthValueAggregations(visitor, node, context, windowDefinition,
node.getNumberOfDataPoints())
.stream()
.reduce(Add::new)
.orElse(null);

return getAlias(node.getAlias(), new Divide(WMAExpression, divider));
}

/**
* Helper method to produce an Alias expression with the provided value and name.
* @param name The name for the Alias.
* @param expression The expression which will be evaluated.
* @return An Alias instance with the logical plan representation of `expression AS name`.
*/
private static NamedExpression getAlias(String name, Expression expression) {
// Goes through the Scala companion object (Alias$.MODULE$), so every argument is passed explicitly.
return org.apache.spark.sql.catalyst.expressions.Alias$.MODULE$.apply(expression,
name,
// A newly generated expression id for this alias.
NamedExpression.newExprId(),
// NOTE(review): the remaining arguments are an empty sequence, an empty Option and another empty
// sequence; presumably qualifier / explicit metadata / non-inheritable metadata keys - confirm
// against the Alias signature of the Spark version in use.
seq(Collections.emptyList()),
Option.empty(),
seq(Collections.emptyList()));
}

/**
* Helper method to retrieve an int in expression form for logical plan composition purposes.
* @param expressionVisitor Visitor instance to process the incoming object.
* @param context Context instance to retrieve the Expression instance.
* @param i Target value for the expression.
* @return An expression object which contains the integer value i.
*/
static Expression getIntExpression(CatalystExpressionVisitor expressionVisitor, CatalystPlanContext context, int i) {
    // Hand an INTEGER literal to the visitor, then take the resulting expression off the context.
    Literal intLiteral = new Literal(i, DataType.INTEGER);
    expressionVisitor.visitLiteral(intLiteral, context);
    return context.popNamedParseExpressions().get();
}

static Expression getStringExpression(CatalystExpressionVisitor expressionVisitor, CatalystPlanContext context, UnresolvedExpression exp) {
expressionVisitor.visitLiteral(new Literal(exp,
/**
* Helper method to retrieve a String in expression form for logical plan composition purpose.
* @param expressionVisitor Visitor instance to process the incoming object.
* @param context Context instance to retrieve the Expression instance.
* @param unresolvedExpression The unresolvedExpression instance that required the conversion.
* @return An expression object which contain the incoming object.
*/
static Expression getStringExpression(CatalystExpressionVisitor expressionVisitor, CatalystPlanContext context, UnresolvedExpression unresolvedExpression) {
expressionVisitor.visitLiteral(new Literal(unresolvedExpression,
DataType.STRING), context);
return context.popNamedParseExpressions().get();

}

static WindowSpecDefinition getCommonWindowDefinition(Expression dataField, boolean ascending, Expression windowLowerBoundary) {
/**
* Helper method to retrieve a WindowSpecDefinition with provided sorting condition.
* `windowspecdefinition('sortField ascending NULLS FIRST, specifiedwindowframe(RowFrame, windowLowerBoundary, currentrow$())`
* @param sortField The field being used for the sorting operation.
* @param ascending The boolean instance for the sorting order.
* @param windowLowerBoundary The Integer expression instance which specify the even lookbehind / lookahead.
* @return A WindowSpecDefinition instance which will be used to composite the WMA calculation.
*/
static WindowSpecDefinition getCommonWindowDefinition(Expression sortField, boolean ascending, Expression windowLowerBoundary) {
return new WindowSpecDefinition(
seq(),
seq(SortUtils.sortOrder(dataField, ascending)),
seq(SortUtils.sortOrder(sortField, ascending)),
new SpecifiedWindowFrame(RowFrame$.MODULE$, windowLowerBoundary, CurrentRow$.MODULE$));
}

private static @NotNull List<Expression> getNthValueAggregations(CatalystExpressionVisitor expressionVisitor,
Trendline.TrendlineComputation node,
CatalystPlanContext context,
WindowSpecDefinition windowDefinition, int offSet) {

List<Expression> expressions = new ArrayList<Expression>();
for (int i = 1; i <= offSet; i++) {
/**
* To produce a list of Expression with responsible to return appropriate lookbehind / lookahead value for WMA calculation, sample logical plan listed below.
* (((('nth_value('salary, 1) windowspecdefinition(Field(field=age, fieldArgs=[]) ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 1) +
*
* @param visitor Visitor instance to resolve Expression.
* @param node Treeline command instruction.
* @param context Context instance to retrieve the resolved expression.
* @param windowDefinition The windowDefinition for the inidividual datapoint lookbehind / lookahead.
* @param dataPoints Number of datapoints for WMA calculation, this will always equal to number of Expression being generated.
* @return List instance which contain the SQL statement for WMA individual datapoint's calculations.
*/
private static List<Expression> getNthValueAggregations(CatalystExpressionVisitor visitor,
Trendline.TrendlineComputation node,
CatalystPlanContext context,
WindowSpecDefinition windowDefinition,
int dataPoints) {

List<Expression> expressions = new ArrayList<>();
for (int i = 1; i <= dataPoints; i++) {
// Get the offset parameter
Literal offSetLiteral = new Literal(i, DataType.INTEGER);
expressionVisitor.visitLiteral(offSetLiteral, context);
Expression offSetExpression = context.popNamedParseExpressions().get();
Expression offSetExpression = getIntExpression(visitor, context, i);

// Composite the nth_value expression.
Function func = new Function(BuiltinFunctionName.NTH_VALUE.name(),
List.of(node.getDataField(), offSetLiteral));
List.of(node.getDataField(), new Literal(i, DataType.INTEGER)));

expressionVisitor.visitFunction(func, context);
visitor.visitFunction(func, context);
Expression nthValueExp = context.popNamedParseExpressions().get();

expressions.add(
new Multiply(new WindowExpression(nthValueExp, windowDefinition), offSetExpression));
}
return expressions;
}


/**
 * Chains the given expressions into a single left-nested Add expression.
 *
 * @param expressions The expressions to add together, in order.
 * @return The combined expression, the sole element when the list has one entry,
 *         or null when the list is empty.
 */
private static Expression sum(List<Expression> expressions) {
    Optional<Expression> total = expressions.stream().reduce(Add::new);
    return total.orElse(null);
}



/**
 * Wraps a trendline window expression so that it evaluates to a null literal whenever
 * the window holds fewer rows than the number of data points requested by the command.
 *
 * @param expressionVisitor Visitor used to turn literals and the COUNT aggregate into expressions.
 * @param trendlineWindow The window expression that computes the trendline value.
 * @param node Trendline command arguments; supplies the required number of data points.
 * @param context Context from which each visited expression is popped.
 * @return A CaseWhen whose single branch yields null when COUNT(1) over the trendline's window spec
 *         is less than the required number of data points, and whose else value is the trendline window.
 */
private static CaseWhen trendlineOrNullWhenThereAreTooFewDataPoints(CatalystExpressionVisitor expressionVisitor, WindowExpression trendlineWindow, Trendline.TrendlineComputation node, CatalystPlanContext context) {
    // Minimum number of rows the window must contain.
    Literal threshold = new Literal(node.getNumberOfDataPoints(), DataType.INTEGER);
    expressionVisitor.visitLiteral(threshold, context);
    Expression minimumRows = context.popNamedParseExpressions().get();

    // COUNT(1), evaluated over the very same window spec as the trendline itself.
    AggregateFunction countOne = new AggregateFunction(BuiltinFunctionName.COUNT.name(), new Literal(1, DataType.INTEGER));
    expressionVisitor.visitAggregateFunction(countOne, context);
    Expression rowCounter = context.popNamedParseExpressions().get();
    WindowExpression rowsInWindow = new WindowExpression(rowCounter, trendlineWindow.windowSpec());

    // Value produced when the window is too small.
    expressionVisitor.visitLiteral(new Literal(null, DataType.NULL), context);
    Expression nullResult = context.popNamedParseExpressions().get();

    LessThan tooFewRows = new LessThan(rowsInWindow, minimumRows);
    Tuple2<Expression, Expression> nullBranch = new Tuple2<>(tooFewRows, nullResult);
    return new CaseWhen(seq(nullBranch), Option.apply(trendlineWindow));
}


/**
 * Helper method to produce an Alias expression for the given expression under the given name.
 * @param name The name for the Alias.
 * @param expression The expression which will be evaluated.
 * @return The result of Alias.apply for `expression` and `name`, with a newly generated expression id.
 */
private static NamedExpression getAlias(String name, Expression expression) {
// Goes through the Scala companion object (Alias$.MODULE$), so every argument is passed explicitly.
return org.apache.spark.sql.catalyst.expressions.Alias$.MODULE$.apply(expression,
name,
NamedExpression.newExprId(),
// NOTE(review): the remaining arguments are an empty sequence, an empty Option and another empty
// sequence; presumably qualifier / explicit metadata / non-inheritable metadata keys - confirm
// against the Alias signature of the Spark version in use.
seq(new java.util.ArrayList<String>()),
Option.empty(),
seq(new java.util.ArrayList<String>()));
}
}

0 comments on commit 5a8592f

Please sign in to comment.