Skip to content

Commit

Permalink
[KIE-775] drools executable-model fails with a bind variable to a cal…
Browse files Browse the repository at this point in the history
…culation result of int and BigDecimal (apache#5636)
  • Loading branch information
tkobayas committed Jan 12, 2024
1 parent 40fcd42 commit 2501a0b
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ public ArithmeticCoercedExpression(TypedExpression left, TypedExpression right,
this.operator = operator;
}

/*
* This coercion only deals with String vs Numeric types.
* BigDecimal arithmetic operation is handled by ExpressionTyper.convertArithmeticBinaryToMethodCall()
*/
public ArithmeticCoercedExpressionResult coerce() {

if (!requiresCoercion()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,9 @@
import static org.drools.modelcompiler.builder.generator.DslMethodNames.createDslTopLevelMethod;
import static org.drools.modelcompiler.builder.generator.drlxparse.MultipleDrlxParseSuccess.createMultipleDrlxParseSuccess;
import static org.drools.modelcompiler.builder.generator.drlxparse.SpecialComparisonCase.specialComparisonFactory;
import static org.drools.modelcompiler.builder.generator.expressiontyper.ExpressionTyper.convertArithmeticBinaryToMethodCall;
import static org.drools.modelcompiler.builder.generator.expressiontyper.ExpressionTyper.getBinaryTypeAfterConversion;
import static org.drools.modelcompiler.builder.generator.expressiontyper.ExpressionTyper.shouldConvertArithmeticBinaryToMethodCall;
import static org.drools.modelcompiler.builder.generator.expressiontyper.FlattenScope.transformFullyQualifiedInlineCastExpr;
import static org.drools.mvel.parser.printer.PrintUtil.printNode;
import static org.drools.mvel.parser.utils.AstUtils.isLogicalOperator;
Expand Down Expand Up @@ -195,11 +198,23 @@ private void logWarnIfNoReactOnCausedByVariableFromDifferentPattern(DrlxParseRes
}

private void addDeclaration(DrlxExpression drlx, SingleDrlxParseSuccess singleResult, String bindId) {
DeclarationSpec decl = context.addDeclaration( bindId, singleResult.getLeftExprTypeBeforeCoercion() );
DeclarationSpec decl = context.addDeclaration(bindId, getDeclarationType(drlx, singleResult));
if (drlx.getExpr() instanceof NameExpr) {
decl.setBoundVariable( PrintUtil.printNode(drlx.getExpr()) );
} else if (drlx.getExpr() instanceof BinaryExpr) {
decl.setBoundVariable(PrintUtil.printNode(drlx.getExpr().asBinaryExpr().getLeft()));
Expression leftMostExpression = getLeftMostExpression(drlx.getExpr().asBinaryExpr());
decl.setBoundVariable(PrintUtil.printNode(leftMostExpression));
if (singleResult.getExpr() instanceof MethodCallExpr) {
// BinaryExpr was converted to MethodCallExpr. Create a TypedExpression for the leftmost expression of the BinaryExpr
ExpressionTyperContext expressionTyperContext = new ExpressionTyperContext();
ExpressionTyper expressionTyper = new ExpressionTyper(context, singleResult.getPatternType(), bindId, false, expressionTyperContext);
TypedExpressionResult leftTypedExpressionResult = expressionTyper.toTypedExpression(leftMostExpression);
Optional<TypedExpression> optLeft = leftTypedExpressionResult.getTypedExpression();
if (!optLeft.isPresent()) {
throw new IllegalStateException("Cannot create TypedExpression for " + drlx.getExpr().asBinaryExpr().getLeft());
}
singleResult.setBoundExpr(optLeft.get());
}
}
decl.setBelongingPatternDescr(context.getCurrentPatternDescr());
singleResult.setExprBinding( bindId );
Expand All @@ -209,6 +224,24 @@ private void addDeclaration(DrlxExpression drlx, SingleDrlxParseSuccess singleRe
}
}

private static Class<?> getDeclarationType(DrlxExpression drlx, SingleDrlxParseSuccess singleResult) {
if (drlx.getBind() != null && drlx.getExpr() instanceof EnclosedExpr) {
// in case of enclosed, bind type should be the calculation result type
// If drlx.getBind() == null, a bind variable is inside the enclosed expression, leave it to the default behavior
return (Class<?>)singleResult.getExprType();
} else {
return singleResult.getLeftExprTypeBeforeCoercion();
}
}

private Expression getLeftMostExpression(BinaryExpr binaryExpr) {
Expression left = binaryExpr.getLeft();
if (left instanceof BinaryExpr) {
return getLeftMostExpression((BinaryExpr) left);
}
return left;
}

/*
This is the entry point for Constraint Transformation from a parsed MVEL constraint
to a Java Expression
Expand Down Expand Up @@ -650,17 +683,16 @@ private DrlxParseResult parseBinaryExpr(BinaryExpr binaryExpr, Class<?> patternT

Expression combo;

boolean arithmeticExpr = isArithmeticOperator(operator);
boolean isBetaConstraint = right.getExpression() != null && hasDeclarationFromOtherPattern( expressionTyperContext );
boolean requiresSplit = operator == BinaryExpr.Operator.AND && binaryExpr.getRight() instanceof HalfBinaryExpr && !isBetaConstraint;

Type exprType = isBooleanOperator( operator ) ? boolean.class : left.getType();

if (equalityExpr) {
combo = getEqualityExpression( left, right, operator ).expression;
} else if (arithmeticExpr && (left.isBigDecimal())) {
ConstraintCompiler constraintCompiler = createConstraintCompiler(this.context, of(patternType));
CompiledExpressionResult compiledExpressionResult = constraintCompiler.compileExpression(binaryExpr);

combo = compiledExpressionResult.getExpression();
} else if (shouldConvertArithmeticBinaryToMethodCall(operator, left.getType(), right.getType())) {
combo = convertArithmeticBinaryToMethodCall(binaryExpr, of(patternType), this.context);
exprType = getBinaryTypeAfterConversion(left.getType(), right.getType());
} else {
if (left.getExpression() == null || right.getExpression() == null) {
return new DrlxParseFail(new ParseExpressionErrorResult(drlxExpr));
Expand Down Expand Up @@ -688,7 +720,7 @@ private DrlxParseResult parseBinaryExpr(BinaryExpr binaryExpr, Class<?> patternT
constraintType = Index.ConstraintType.FORALL_SELF_JOIN;
}

return new SingleDrlxParseSuccess(patternType, bindingId, combo, isBooleanOperator( operator ) ? boolean.class : left.getType())
return new SingleDrlxParseSuccess(patternType, bindingId, combo, exprType)
.setDecodeConstraintType( constraintType )
.setUsedDeclarations( expressionTyperContext.getUsedDeclarations() )
.setUsedDeclarationsOnLeft( usedDeclarationsOnLeft )
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,14 @@ private Optional<TypedExpression> toTypedExpressionRec(Expression drlxExpr) {
right = coerced.getCoercedRight();

final BinaryExpr combo = new BinaryExpr(left.getExpression(), right.getExpression(), operator);
return of(new TypedExpression(combo, left.getType()));

if (shouldConvertArithmeticBinaryToMethodCall(operator, left.getType(), right.getType())) {
Expression expression = convertArithmeticBinaryToMethodCall(combo, of(typeCursor), ruleContext);
java.lang.reflect.Type binaryType = getBinaryTypeAfterConversion(left.getType(), right.getType());
return of(new TypedExpression(expression, binaryType));
} else {
return of(new TypedExpression(combo, left.getType()));
}
}

if (drlxExpr instanceof HalfBinaryExpr) {
Expand Down Expand Up @@ -809,11 +816,12 @@ private TypedExpressionCursor binaryExpr(BinaryExpr binaryExpr) {
TypedExpression rightTypedExpression = right.getTypedExpression()
.orElseThrow(() -> new NoSuchElementException("TypedExpressionResult doesn't contain TypedExpression!"));
binaryExpr.setRight(rightTypedExpression.getExpression());
java.lang.reflect.Type binaryType = getBinaryType(leftTypedExpression, rightTypedExpression, binaryExpr.getOperator());
if (shouldConvertArithmeticBinaryToMethodCall(binaryExpr.getOperator(), binaryType)) {
if (shouldConvertArithmeticBinaryToMethodCall(binaryExpr.getOperator(), leftTypedExpression.getType(), rightTypedExpression.getType())) {
Expression compiledExpression = convertArithmeticBinaryToMethodCall(binaryExpr, leftTypedExpression.getOriginalPatternType(), ruleContext);
java.lang.reflect.Type binaryType = getBinaryTypeAfterConversion(leftTypedExpression.getType(), rightTypedExpression.getType());
return new TypedExpressionCursor(compiledExpression, binaryType);
} else {
java.lang.reflect.Type binaryType = getBinaryType(leftTypedExpression, rightTypedExpression, binaryExpr.getOperator());
return new TypedExpressionCursor(binaryExpr, binaryType);
}
}
Expand All @@ -822,7 +830,7 @@ private TypedExpressionCursor binaryExpr(BinaryExpr binaryExpr) {
* Converts arithmetic binary expression (including coercion) to method call using ConstraintCompiler.
* This method can be generic, so we may centralize the calls in drools-model
*/
private static Expression convertArithmeticBinaryToMethodCall(BinaryExpr binaryExpr, Optional<Class<?>> originalPatternType, RuleContext ruleContext) {
public static Expression convertArithmeticBinaryToMethodCall(BinaryExpr binaryExpr, Optional<Class<?>> originalPatternType, RuleContext ruleContext) {
ConstraintCompiler constraintCompiler = createConstraintCompiler(ruleContext, originalPatternType);
CompiledExpressionResult compiledExpressionResult = constraintCompiler.compileExpression(printNode(binaryExpr));
return compiledExpressionResult.getExpression();
Expand All @@ -831,8 +839,15 @@ private static Expression convertArithmeticBinaryToMethodCall(BinaryExpr binaryE
/*
* BigDecimal arithmetic operations should be converted to method calls. We may also apply this to BigInteger.
*/
private static boolean shouldConvertArithmeticBinaryToMethodCall(BinaryExpr.Operator operator, java.lang.reflect.Type type) {
return isArithmeticOperator(operator) && type.equals(BigDecimal.class);
public static boolean shouldConvertArithmeticBinaryToMethodCall(BinaryExpr.Operator operator, java.lang.reflect.Type leftType, java.lang.reflect.Type rightType) {
return isArithmeticOperator(operator) && (leftType.equals(BigDecimal.class) || rightType.equals(BigDecimal.class));
}

/*
* After arithmetic to method call conversion, BigDecimal should take precedence regardless of left or right. We may also apply this to BigInteger.
*/
public static java.lang.reflect.Type getBinaryTypeAfterConversion(java.lang.reflect.Type leftType, java.lang.reflect.Type rightType) {
return (leftType.equals(BigDecimal.class) || rightType.equals(BigDecimal.class)) ? BigDecimal.class : leftType;
}

private java.lang.reflect.Type getBinaryType(TypedExpression leftTypedExpression, TypedExpression rightTypedExpression, Operator operator) {
Expand Down Expand Up @@ -939,6 +954,9 @@ private void promoteBigDecimalParameters(MethodCallExpr methodCallExpr, Class[]
Expression argumentExpression = methodCallExpr.getArgument(i);

if (argumentType != actualArgumentType) {
// unbind the original argumentExpression first, otherwise setArgument() will remove the argumentExpression from coercedExpression.childrenNodes
// It will result in failing to find DrlNameExpr in AST at DrlsParseUtil.transformDrlNameExprToNameExpr()
methodCallExpr.replace(argumentExpression, new NameExpr("placeholder"));
Expression coercedExpression = new BigDecimalArgumentCoercion().coercedArgument(argumentType, actualArgumentType, argumentExpression);
methodCallExpr.setArgument(i, coercedExpression);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -742,4 +742,45 @@ public void bigDecimalCoercionInNestedMethodArgument_shouldNotFailToBuild() {
public static String intToString(int value) {
return Integer.toString(value);
}

@Test
public void bindVariableToBigDecimalCoercion2Operands_shouldBindCorrectResult() {
bindVariableToBigDecimalCoercion("$var : (1000 * value1)");
}

@Test
public void bindVariableToBigDecimalCoercion3Operands_shouldBindCorrectResult() {
bindVariableToBigDecimalCoercion("$var : (100000 * value1 / 100)");
}

@Test
public void bindVariableToBigDecimalCoercion3OperandsWithParentheses_shouldBindCorrectResult() {
bindVariableToBigDecimalCoercion("$var : ((100000 * value1) / 100)");
}

private void bindVariableToBigDecimalCoercion(String binding) {
// KIE-775
String str =
"package org.drools.modelcompiler.bigdecimals\n" +
"import " + BDFact.class.getCanonicalName() + ";\n" +
"global java.util.List result;\n" +
"rule R1\n" +
" when\n" +
" BDFact( " + binding + " )\n" +
" then\n" +
" result.add($var);\n" +
"end";

KieSession ksession = getKieSession(str);
List<BigDecimal> result = new ArrayList<>();
ksession.setGlobal("result", result);

BDFact bdFact = new BDFact();
bdFact.setValue1(new BigDecimal("80"));

ksession.insert(bdFact);
ksession.fireAllRules();

assertThat(result).contains(new BigDecimal("80000"));
}
}

0 comments on commit 2501a0b

Please sign in to comment.