Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DROOLS-7631] unify coercion checks between plain drl and executable model #6086

Merged
merged 1 commit into from
Sep 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@
import static org.drools.model.codegen.execmodel.generator.DrlxParseUtil.toJavaParserType;
import static org.drools.model.codegen.execmodel.generator.DrlxParseUtil.toStringLiteral;
import static org.drools.util.ClassUtils.toNonPrimitiveType;
import static org.drools.util.CoercionUtil.areComparisonCompatible;
import static org.drools.util.CoercionUtil.areEqualityCompatible;

public class CoercedExpression {

Expand Down Expand Up @@ -151,9 +153,22 @@ public CoercedExpressionResult coerce() {
coercedLeft = left;
}

checkCoercion(coercedLeft, coercedRight, leftClass, rightClass);
return new CoercedExpressionResult(coercedLeft, coercedRight, rightAsStaticField);
}

private void checkCoercion(TypedExpression coercedLeft, TypedExpression coercedRight, Class<?> leftClass, Class<?> rightClass) {
if (equalityExpr) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since this is known at creation time you can inline the check and put instead a lambda inside the expression? so you would not repeat the check for every invocation of the cohercion.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe that this will reduce a bit the readability, and if you are suggesting this only for performance reason I'm pretty sure that invoking a bimorphic capturing lambda will be order of magnitude slower than accessing a local boolean variable.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Performance wise I am ignorant and you are probably right. My point is that if something that affects the behavious of an object is known at creation time of an object you should make the object behave accordingly. But I admit it that this is a minor point and we can keep it as it is. thank you.

if (!areEqualityCompatible(coercedLeft.getRawClass(), coercedRight.getRawClass())) {
throw new CoercedExpressionException(new InvalidExpressionErrorResult("Equality operation requires compatible types. Found " + leftClass + " and " + rightClass));
}
} else {
if (!areComparisonCompatible(coercedLeft.getRawClass(), coercedRight.getRawClass())) {
throw new CoercedExpressionException(new InvalidExpressionErrorResult("Comparison operation requires compatible types. Found " + leftClass + " and " + rightClass));
}
}
}

private boolean isBoolean(Class<?> leftClass) {
return Boolean.class.isAssignableFrom(leftClass) || boolean.class.isAssignableFrom(leftClass);
}
Expand All @@ -163,12 +178,14 @@ private boolean shouldCoerceBToMap() {
}

private boolean canCoerce() {
final Class<?> leftClass = left.getRawClass();
return canCoerce(left.getRawClass(), right.getRawClass());
}

private static boolean canCoerce(Class<?> leftClass, Class<?> rightClass) {
if (!leftClass.isPrimitive() || !canCoerceLiteralNumberExpr(leftClass)) {
return true;
}

final Class<?> rightClass = right.getRawClass();
return rightClass.isPrimitive()
|| Number.class.isAssignableFrom(rightClass)
|| Boolean.class == rightClass
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,9 +127,9 @@ public void castToShort() {
@Test
public void castMaps() {
final TypedExpression left = expr(THIS_PLACEHOLDER + ".getAge()", Integer.class);
final TypedExpression right = expr("$m.get(\"age\")", java.util.Map.class);
final TypedExpression right = expr("$m.get(\"age\")", Object.class);
final CoercedExpression.CoercedExpressionResult coerce = new CoercedExpression(left, right, false).coerce();
assertThat(coerce.getCoercedRight()).isEqualTo(expr("(java.lang.Integer)$m.get(\"age\")", Map.class));
assertThat(coerce.getCoercedRight()).isEqualTo(expr("$m.get(\"age\")", Object.class));
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,8 @@
import org.drools.mvel.expr.MVELObjectExpression;
import org.drools.mvel.expr.MvelEvaluator;
import org.drools.mvel.java.JavaForMvelDialectConfiguration;
import org.drools.util.CoercionUtil;
import org.drools.util.MethodUtils;
import org.kie.api.definition.rule.Rule;
import org.mvel2.ConversionHandler;
import org.mvel2.DataConversion;
Expand Down Expand Up @@ -426,8 +428,6 @@ private MVELCompilationUnit buildCompilationUnit( final RuleBuildContext context
}
}

MVELDialect dialect = (MVELDialect) context.getDialect( "mvel" );

MVELCompilationUnit unit = null;

try {
Expand All @@ -439,16 +439,8 @@ private MVELCompilationUnit buildCompilationUnit( final RuleBuildContext context
((ClassObjectType) p.getObjectType()).getClassType() );
}

unit = dialect.getMVELCompilationUnit( (String) predicateDescr.getContent(),
analysis,
previousDeclarations,
localDeclarations,
null,
context,
"drools",
KnowledgeHelper.class,
context.isInXpath(),
MVELCompilationUnit.Scope.CONSTRAINT );
unit = MVELDialect.getMVELCompilationUnit( (String) predicateDescr.getContent(), analysis, previousDeclarations, localDeclarations,
null, context, "drools", KnowledgeHelper.class, context.isInXpath(), MVELCompilationUnit.Scope.CONSTRAINT );
} catch ( final Exception e ) {
copyErrorLocation(e, predicateDescr);
context.addError( new DescrBuildError( context.getParentDescr(),
Expand Down Expand Up @@ -486,48 +478,14 @@ private StringCoercionCompatibilityEvaluator() { }

@Override
public boolean areEqualityCompatible(Class<?> c1, Class<?> c2) {
if (c1 == NullType.class || c2 == NullType.class) {
return true;
}
if (c1 == String.class || c2 == String.class) {
return true;
}
Class<?> boxed1 = convertFromPrimitiveType(c1);
Class<?> boxed2 = convertFromPrimitiveType(c2);
if (boxed1.isAssignableFrom(boxed2) || boxed2.isAssignableFrom(boxed1)) {
return true;
}
if (Number.class.isAssignableFrom(boxed1) && Number.class.isAssignableFrom(boxed2)) {
return true;
}
if (areEqualityCompatibleEnums(boxed1, boxed2)) {
return true;
}
return !Modifier.isFinal(c1.getModifiers()) && !Modifier.isFinal(c2.getModifiers());
}

protected boolean areEqualityCompatibleEnums(final Class<?> boxed1,
final Class<?> boxed2) {
return boxed1.isEnum() && boxed2.isEnum() && boxed1.getName().equals(boxed2.getName())
&& equalEnumConstants(boxed1.getEnumConstants(), boxed2.getEnumConstants());
}

private boolean equalEnumConstants(final Object[] aa,
final Object[] bb) {
if (aa.length != bb.length) {
return false;
}
for (int i = 0; i < aa.length; i++) {
if (!Objects.equals(aa[i].getClass().getName(), bb[i].getClass().getName())) {
return false;
}
}
return true;
return CoercionUtil.areEqualityCompatible(c1 == NullType.class ? MethodUtils.NullType.class : c1,
c2 == NullType.class ? MethodUtils.NullType.class : c2);
}

@Override
public boolean areComparisonCompatible(Class<?> c1, Class<?> c2) {
return areEqualityCompatible(c1, c2);
return CoercionUtil.areComparisonCompatible(c1 == NullType.class ? MethodUtils.NullType.class : c1,
c2 == NullType.class ? MethodUtils.NullType.class : c2);
}
}

Expand Down Expand Up @@ -558,16 +516,8 @@ public TimerExpression buildTimerExpression( String expression, RuleBuildContext
}
Arrays.sort(previousDeclarations, SortDeclarations.instance);

MVELCompilationUnit unit = dialect.getMVELCompilationUnit( expression,
analysis,
previousDeclarations,
null,
null,
context,
"drools",
KnowledgeHelper.class,
false,
MVELCompilationUnit.Scope.EXPRESSION );
MVELCompilationUnit unit = MVELDialect.getMVELCompilationUnit( expression, analysis, previousDeclarations, null, null,
context, "drools", KnowledgeHelper.class, false, MVELCompilationUnit.Scope.EXPRESSION );

MVELObjectExpression expr = new MVELObjectExpression( unit, dialect.getId() );

Expand All @@ -578,9 +528,7 @@ public TimerExpression buildTimerExpression( String expression, RuleBuildContext
return expr;
} catch ( final Exception e ) {
AsmUtil.copyErrorLocation(e, context.getRuleDescr());
context.addError( new DescrBuildError( context.getParentDescr(),
context.getRuleDescr(),
null,
context.addError( new DescrBuildError( context.getParentDescr(), context.getRuleDescr(), null,
"Unable to build expression : " + e.getMessage() + "'" + expression + "'" ) );
return null;
} finally {
Expand All @@ -595,10 +543,8 @@ public AnalysisResult analyzeExpression(Class<?> thisClass, String expr) {
return analyzeExpression( expr, conf, new BoundIdentifiers( thisClass ) );
}

private static MVELAnalysisResult analyzeExpression(String expr,
ParserConfiguration conf,
BoundIdentifiers availableIdentifiers) {
if (expr.trim().length() == 0) {
private static MVELAnalysisResult analyzeExpression(String expr, ParserConfiguration conf, BoundIdentifiers availableIdentifiers) {
if (expr.trim().isEmpty()) {
MVELAnalysisResult result = analyze( (Set<String> ) Collections.EMPTY_SET, availableIdentifiers );
result.setMvelVariables( new HashMap<>() );
result.setTypesafe( true );
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,14 @@

import org.drools.testcoverage.common.util.KieBaseTestConfiguration;
import org.drools.testcoverage.common.util.KieBaseUtil;
import org.drools.testcoverage.common.util.KieUtil;
import org.drools.testcoverage.common.util.TestParametersUtil;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.kie.api.KieBase;
import org.kie.api.builder.KieBuilder;
import org.kie.api.builder.Message;
import org.kie.api.runtime.KieSession;

import static org.assertj.core.api.Assertions.assertThat;
Expand Down Expand Up @@ -176,4 +179,37 @@ public void testDateCoercionWithNestedOr() {
assertThat(list.size()).isEqualTo(1);
assertThat(list.get(0)).isEqualTo("working");
}

@Test
public void testLocalDateTimeCoercion() {
// DROOLS-7631
String drl = "import java.util.Date\n" +
"import java.time.LocalDateTime\n" +
"global java.util.List list\n" +
"declare DateContainer\n" +
" date: Date\n" +
"end\n" +
"declare LocalDateTimeContainer\n" +
" date: LocalDateTime\n" +
"end\n" +
"\n" +
"rule Init when\n" +
"then\n" +
" insert(new DateContainer(new Date( 1439882189744L )));" +
" insert(new LocalDateTimeContainer( LocalDateTime.now() ));" +
"end\n" +
"\n" +
"rule \"Test rule\"\n" +
"when\n" +
" DateContainer( $date: date )\n" +
" LocalDateTimeContainer( date > $date )\n" +
"then\n" +
" list.add(\"working\");\n" +
"end\n";

KieBuilder kieBuilder = KieUtil.getKieBuilderFromDrls(kieBaseTestConfiguration, false, drl);
List<Message> errors = kieBuilder.getResults().getMessages(Message.Level.ERROR);
assertThat(errors).hasSize(1);
assertThat(errors.get(0).getText()).contains("Comparison operation requires compatible types");
}
}
51 changes: 50 additions & 1 deletion drools-util/src/main/java/org/drools/util/CoercionUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,14 @@

package org.drools.util;

import java.lang.reflect.Modifier;
import java.math.BigDecimal;
import java.math.BigInteger;
import java.time.chrono.ChronoLocalDateTime;
import java.time.temporal.Temporal;
import java.util.Objects;

import org.drools.util.MathUtils;
import static org.drools.util.ClassUtils.convertFromPrimitiveType;

public class CoercionUtil {

Expand Down Expand Up @@ -194,4 +198,49 @@ public static Number coerceToNumber(Number value, Class<?> toClass) {
}
return ret;
}

public static boolean areEqualityCompatible(Class<?> c1, Class<?> c2) {
if (c1 == MethodUtils.NullType.class || c2 == MethodUtils.NullType.class) {
return true;
}
if (c1 == String.class || c2 == String.class) {
return true;
}
if (Temporal.class.isAssignableFrom(c1) && Temporal.class.isAssignableFrom(c2)) {
return true;
}
Class<?> boxed1 = convertFromPrimitiveType(c1);
Class<?> boxed2 = convertFromPrimitiveType(c2);
if (boxed1.isAssignableFrom(boxed2) || boxed2.isAssignableFrom(boxed1)) {
return true;
}
if (Number.class.isAssignableFrom(boxed1) && Number.class.isAssignableFrom(boxed2)) {
return true;
}
if (areEqualityCompatibleEnums(boxed1, boxed2)) {
return true;
}
return !Modifier.isFinal(c1.getModifiers()) && !Modifier.isFinal(c2.getModifiers());
}

protected static boolean areEqualityCompatibleEnums(Class<?> boxed1, Class<?> boxed2) {
return boxed1.isEnum() && boxed2.isEnum() && boxed1.getName().equals(boxed2.getName())
&& equalEnumConstants(boxed1.getEnumConstants(), boxed2.getEnumConstants());
}

private static boolean equalEnumConstants(Object[] aa, Object[] bb) {
if (aa.length != bb.length) {
return false;
}
for (int i = 0; i < aa.length; i++) {
if (!Objects.equals(aa[i].getClass().getName(), bb[i].getClass().getName())) {
return false;
}
}
return true;
}

public static boolean areComparisonCompatible(Class<?> c1, Class<?> c2) {
return areEqualityCompatible(c1, c2);
}
}
Loading