From 0e63fec1b8d77a364e43d57a819d030ef749e4ca Mon Sep 17 00:00:00 2001 From: mariofusco Date: Tue, 17 Sep 2024 12:02:01 +0200 Subject: [PATCH] [DROOLS-7631] unify coercion checks between plain drl and executable model --- .../drlxparse/CoercedExpression.java | 21 ++++- .../drlxparse/CoercedExpressionTest.java | 4 +- .../drools/mvel/MVELConstraintBuilder.java | 80 +++---------------- .../integrationtests/DateCoercionTest.java | 36 +++++++++ .../java/org/drools/util/CoercionUtil.java | 51 +++++++++++- 5 files changed, 120 insertions(+), 72 deletions(-) diff --git a/drools-model/drools-model-codegen/src/main/java/org/drools/model/codegen/execmodel/generator/drlxparse/CoercedExpression.java b/drools-model/drools-model-codegen/src/main/java/org/drools/model/codegen/execmodel/generator/drlxparse/CoercedExpression.java index 4bbd12d9630..e26c1332734 100644 --- a/drools-model/drools-model-codegen/src/main/java/org/drools/model/codegen/execmodel/generator/drlxparse/CoercedExpression.java +++ b/drools-model/drools-model-codegen/src/main/java/org/drools/model/codegen/execmodel/generator/drlxparse/CoercedExpression.java @@ -52,6 +52,8 @@ import static org.drools.model.codegen.execmodel.generator.DrlxParseUtil.toJavaParserType; import static org.drools.model.codegen.execmodel.generator.DrlxParseUtil.toStringLiteral; import static org.drools.util.ClassUtils.toNonPrimitiveType; +import static org.drools.util.CoercionUtil.areComparisonCompatible; +import static org.drools.util.CoercionUtil.areEqualityCompatible; public class CoercedExpression { @@ -151,9 +153,22 @@ public CoercedExpressionResult coerce() { coercedLeft = left; } + checkCoercion(coercedLeft, coercedRight, leftClass, rightClass); return new CoercedExpressionResult(coercedLeft, coercedRight, rightAsStaticField); } + private void checkCoercion(TypedExpression coercedLeft, TypedExpression coercedRight, Class leftClass, Class rightClass) { + if (equalityExpr) { + if (!areEqualityCompatible(coercedLeft.getRawClass(), coercedRight.getRawClass())) { + throw new CoercedExpressionException(new InvalidExpressionErrorResult("Equality operation requires compatible types. Found " + leftClass + " and " + rightClass)); + } + } else { + if (!areComparisonCompatible(coercedLeft.getRawClass(), coercedRight.getRawClass())) { + throw new CoercedExpressionException(new InvalidExpressionErrorResult("Comparison operation requires compatible types. Found " + leftClass + " and " + rightClass)); + } + } + } + private boolean isBoolean(Class leftClass) { return Boolean.class.isAssignableFrom(leftClass) || boolean.class.isAssignableFrom(leftClass); } @@ -163,12 +178,14 @@ private boolean shouldCoerceBToMap() { } private boolean canCoerce() { - final Class leftClass = left.getRawClass(); + return canCoerce(left.getRawClass(), right.getRawClass()); + } + + private static boolean canCoerce(Class leftClass, Class rightClass) { if (!leftClass.isPrimitive() || !canCoerceLiteralNumberExpr(leftClass)) { return true; } - final Class rightClass = right.getRawClass(); return rightClass.isPrimitive() || Number.class.isAssignableFrom(rightClass) || Boolean.class == rightClass diff --git a/drools-model/drools-model-codegen/src/test/java/org/drools/model/codegen/execmodel/generator/drlxparse/CoercedExpressionTest.java b/drools-model/drools-model-codegen/src/test/java/org/drools/model/codegen/execmodel/generator/drlxparse/CoercedExpressionTest.java index 1264a096981..83e1a2c3a6d 100644 --- a/drools-model/drools-model-codegen/src/test/java/org/drools/model/codegen/execmodel/generator/drlxparse/CoercedExpressionTest.java +++ b/drools-model/drools-model-codegen/src/test/java/org/drools/model/codegen/execmodel/generator/drlxparse/CoercedExpressionTest.java @@ -127,9 +127,9 @@ public void castToShort() { @Test public void castMaps() { final TypedExpression left = expr(THIS_PLACEHOLDER + ".getAge()", Integer.class); - final TypedExpression right = expr("$m.get(\"age\")", java.util.Map.class); + final TypedExpression right = expr("$m.get(\"age\")", Object.class); final CoercedExpression.CoercedExpressionResult coerce = new CoercedExpression(left, right, false).coerce(); - assertThat(coerce.getCoercedRight()).isEqualTo(expr("(java.lang.Integer)$m.get(\"age\")", Map.class)); + assertThat(coerce.getCoercedRight()).isEqualTo(expr("$m.get(\"age\")", Object.class)); } @Test diff --git a/drools-mvel/src/main/java/org/drools/mvel/MVELConstraintBuilder.java b/drools-mvel/src/main/java/org/drools/mvel/MVELConstraintBuilder.java index 38809896b71..2b9fb57ed10 100644 --- a/drools-mvel/src/main/java/org/drools/mvel/MVELConstraintBuilder.java +++ b/drools-mvel/src/main/java/org/drools/mvel/MVELConstraintBuilder.java @@ -97,6 +97,8 @@ import org.drools.mvel.expr.MVELObjectExpression; import org.drools.mvel.expr.MvelEvaluator; import org.drools.mvel.java.JavaForMvelDialectConfiguration; +import org.drools.util.CoercionUtil; +import org.drools.util.MethodUtils; import org.kie.api.definition.rule.Rule; import org.mvel2.ConversionHandler; import org.mvel2.DataConversion; @@ -426,8 +428,6 @@ private MVELCompilationUnit buildCompilationUnit( final RuleBuildContext context } } - MVELDialect dialect = (MVELDialect) context.getDialect( "mvel" ); - MVELCompilationUnit unit = null; try { @@ -439,16 +439,8 @@ private MVELCompilationUnit buildCompilationUnit( final RuleBuildContext context ((ClassObjectType) p.getObjectType()).getClassType() ); } - unit = dialect.getMVELCompilationUnit( (String) predicateDescr.getContent(), - analysis, - previousDeclarations, - localDeclarations, - null, - context, - "drools", - KnowledgeHelper.class, - context.isInXpath(), - MVELCompilationUnit.Scope.CONSTRAINT ); + unit = MVELDialect.getMVELCompilationUnit( (String) predicateDescr.getContent(), analysis, previousDeclarations, localDeclarations, + null, context, "drools", KnowledgeHelper.class, context.isInXpath(), MVELCompilationUnit.Scope.CONSTRAINT ); } catch ( final Exception e ) { copyErrorLocation(e, predicateDescr); context.addError( new DescrBuildError( context.getParentDescr(), @@ -486,48 +478,14 @@ private StringCoercionCompatibilityEvaluator() { } @Override public boolean areEqualityCompatible(Class c1, Class c2) { - if (c1 == NullType.class || c2 == NullType.class) { - return true; - } - if (c1 == String.class || c2 == String.class) { - return true; - } - Class boxed1 = convertFromPrimitiveType(c1); - Class boxed2 = convertFromPrimitiveType(c2); - if (boxed1.isAssignableFrom(boxed2) || boxed2.isAssignableFrom(boxed1)) { - return true; - } - if (Number.class.isAssignableFrom(boxed1) && Number.class.isAssignableFrom(boxed2)) { - return true; - } - if (areEqualityCompatibleEnums(boxed1, boxed2)) { - return true; - } - return !Modifier.isFinal(c1.getModifiers()) && !Modifier.isFinal(c2.getModifiers()); - } - - protected boolean areEqualityCompatibleEnums(final Class boxed1, - final Class boxed2) { - return boxed1.isEnum() && boxed2.isEnum() && boxed1.getName().equals(boxed2.getName()) - && equalEnumConstants(boxed1.getEnumConstants(), boxed2.getEnumConstants()); - } - - private boolean equalEnumConstants(final Object[] aa, - final Object[] bb) { - if (aa.length != bb.length) { - return false; - } - for (int i = 0; i < aa.length; i++) { - if (!Objects.equals(aa[i].getClass().getName(), bb[i].getClass().getName())) { - return false; - } - } - return true; + return CoercionUtil.areEqualityCompatible(c1 == NullType.class ? MethodUtils.NullType.class : c1, + c2 == NullType.class ? MethodUtils.NullType.class : c2); } @Override public boolean areComparisonCompatible(Class c1, Class c2) { - return areEqualityCompatible(c1, c2); + return CoercionUtil.areComparisonCompatible(c1 == NullType.class ? MethodUtils.NullType.class : c1, + c2 == NullType.class ? MethodUtils.NullType.class : c2); } } @@ -558,16 +516,8 @@ public TimerExpression buildTimerExpression( String expression, RuleBuildContext } Arrays.sort(previousDeclarations, SortDeclarations.instance); - MVELCompilationUnit unit = dialect.getMVELCompilationUnit( expression, - analysis, - previousDeclarations, - null, - null, - context, - "drools", - KnowledgeHelper.class, - false, - MVELCompilationUnit.Scope.EXPRESSION ); + MVELCompilationUnit unit = MVELDialect.getMVELCompilationUnit( expression, analysis, previousDeclarations, null, null, + context, "drools", KnowledgeHelper.class, false, MVELCompilationUnit.Scope.EXPRESSION ); MVELObjectExpression expr = new MVELObjectExpression( unit, dialect.getId() ); @@ -578,9 +528,7 @@ public TimerExpression buildTimerExpression( String expression, RuleBuildContext return expr; } catch ( final Exception e ) { AsmUtil.copyErrorLocation(e, context.getRuleDescr()); - context.addError( new DescrBuildError( context.getParentDescr(), - context.getRuleDescr(), - null, + context.addError( new DescrBuildError( context.getParentDescr(), context.getRuleDescr(), null, "Unable to build expression : " + e.getMessage() + "'" + expression + "'" ) ); return null; } finally { @@ -595,10 +543,8 @@ public AnalysisResult analyzeExpression(Class thisClass, String expr) { return analyzeExpression( expr, conf, new BoundIdentifiers( thisClass ) ); } - private static MVELAnalysisResult analyzeExpression(String expr, - ParserConfiguration conf, - BoundIdentifiers availableIdentifiers) { - if (expr.trim().length() == 0) { + private static MVELAnalysisResult analyzeExpression(String expr, ParserConfiguration conf, BoundIdentifiers availableIdentifiers) { + if (expr.trim().isEmpty()) { MVELAnalysisResult result = analyze( (Set ) Collections.EMPTY_SET, availableIdentifiers ); result.setMvelVariables( new HashMap<>() ); result.setTypesafe( true ); diff --git a/drools-test-coverage/test-compiler-integration/src/test/java/org/drools/mvel/integrationtests/DateCoercionTest.java b/drools-test-coverage/test-compiler-integration/src/test/java/org/drools/mvel/integrationtests/DateCoercionTest.java index 5d6a7157042..1dd24026b54 100644 --- a/drools-test-coverage/test-compiler-integration/src/test/java/org/drools/mvel/integrationtests/DateCoercionTest.java +++ b/drools-test-coverage/test-compiler-integration/src/test/java/org/drools/mvel/integrationtests/DateCoercionTest.java @@ -25,11 +25,14 @@ import org.drools.testcoverage.common.util.KieBaseTestConfiguration; import org.drools.testcoverage.common.util.KieBaseUtil; +import org.drools.testcoverage.common.util.KieUtil; import org.drools.testcoverage.common.util.TestParametersUtil; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.kie.api.KieBase; +import org.kie.api.builder.KieBuilder; +import org.kie.api.builder.Message; import org.kie.api.runtime.KieSession; import static org.assertj.core.api.Assertions.assertThat; @@ -176,4 +179,37 @@ public void testDateCoercionWithNestedOr() { assertThat(list.size()).isEqualTo(1); assertThat(list.get(0)).isEqualTo("working"); } + + @Test + public void testLocalDateTimeCoercion() { + // DROOLS-7631 + String drl = "import java.util.Date\n" + + "import java.time.LocalDateTime\n" + + "global java.util.List list\n" + + "declare DateContainer\n" + + " date: Date\n" + + "end\n" + + "declare LocalDateTimeContainer\n" + + " date: LocalDateTime\n" + + "end\n" + + "\n" + + "rule Init when\n" + + "then\n" + + " insert(new DateContainer(new Date( 1439882189744L )));" + + " insert(new LocalDateTimeContainer( LocalDateTime.now() ));" + + "end\n" + + "\n" + + "rule \"Test rule\"\n" + + "when\n" + + " DateContainer( $date: date )\n" + + " LocalDateTimeContainer( date > $date )\n" + + "then\n" + + " list.add(\"working\");\n" + + "end\n"; + + KieBuilder kieBuilder = KieUtil.getKieBuilderFromDrls(kieBaseTestConfiguration, false, drl); + List errors = kieBuilder.getResults().getMessages(Message.Level.ERROR); + assertThat(errors).hasSize(1); + assertThat(errors.get(0).getText()).contains("Comparison operation requires compatible types"); + } } diff --git a/drools-util/src/main/java/org/drools/util/CoercionUtil.java b/drools-util/src/main/java/org/drools/util/CoercionUtil.java index f91b2bd2358..bb7d5ac0e73 100644 --- a/drools-util/src/main/java/org/drools/util/CoercionUtil.java +++ b/drools-util/src/main/java/org/drools/util/CoercionUtil.java @@ -19,10 +19,14 @@ package org.drools.util; +import java.lang.reflect.Modifier; import java.math.BigDecimal; import java.math.BigInteger; +import java.time.chrono.ChronoLocalDateTime; +import java.time.temporal.Temporal; +import java.util.Objects; -import org.drools.util.MathUtils; +import static org.drools.util.ClassUtils.convertFromPrimitiveType; public class CoercionUtil { @@ -194,4 +198,49 @@ public static Number coerceToNumber(Number value, Class toClass) { } return ret; } + + public static boolean areEqualityCompatible(Class c1, Class c2) { + if (c1 == MethodUtils.NullType.class || c2 == MethodUtils.NullType.class) { + return true; + } + if (c1 == String.class || c2 == String.class) { + return true; + } + if (Temporal.class.isAssignableFrom(c1) && Temporal.class.isAssignableFrom(c2)) { + return true; + } + Class boxed1 = convertFromPrimitiveType(c1); + Class boxed2 = convertFromPrimitiveType(c2); + if (boxed1.isAssignableFrom(boxed2) || boxed2.isAssignableFrom(boxed1)) { + return true; + } + if (Number.class.isAssignableFrom(boxed1) && Number.class.isAssignableFrom(boxed2)) { + return true; + } + if (areEqualityCompatibleEnums(boxed1, boxed2)) { + return true; + } + return !Modifier.isFinal(c1.getModifiers()) && !Modifier.isFinal(c2.getModifiers()); + } + + protected static boolean areEqualityCompatibleEnums(Class boxed1, Class boxed2) { + return boxed1.isEnum() && boxed2.isEnum() && boxed1.getName().equals(boxed2.getName()) + && equalEnumConstants(boxed1.getEnumConstants(), boxed2.getEnumConstants()); + } + + private static boolean equalEnumConstants(Object[] aa, Object[] bb) { + if (aa.length != bb.length) { + return false; + } + for (int i = 0; i < aa.length; i++) { + if (!Objects.equals(aa[i].getClass().getName(), bb[i].getClass().getName())) { + return false; + } + } + return true; + } + + public static boolean areComparisonCompatible(Class c1, Class c2) { + return areEqualityCompatible(c1, c2); + } }