Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[9.101.x-prod] Enhance ConstraintOperator capability (#6104) #81

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,12 @@

public interface ConstraintOperator {
<T, V> BiPredicate<T, V> asPredicate();

default boolean hasIndex() {
return false;
}

default Index.ConstraintType getIndexType() {
throw new UnsupportedOperationException();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,254 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* <p>
* http://www.apache.org/licenses/LICENSE-2.0
* <p>
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.drools.model.codegen.execmodel;

import java.util.function.BiPredicate;

import org.drools.base.common.NetworkNode;
import org.drools.base.prototype.PrototypeObjectType;
import org.drools.base.rule.IndexableConstraint;
import org.drools.core.reteoo.AlphaNode;
import org.drools.core.reteoo.BetaNode;
import org.drools.core.reteoo.EntryPointNode;
import org.drools.core.reteoo.ObjectTypeNode;
import org.drools.kiesession.rulebase.InternalKnowledgeBase;
import org.drools.model.ConstraintOperator;
import org.drools.model.Index;
import org.drools.model.Model;
import org.drools.model.Rule;
import org.drools.model.codegen.execmodel.domain.Result;
import org.drools.model.impl.ModelImpl;
import org.drools.model.prototype.PrototypeVariable;
import org.drools.modelcompiler.KieBaseBuilder;
import org.drools.modelcompiler.constraints.LambdaConstraint;
import org.junit.Test;
import org.kie.api.KieBase;
import org.kie.api.prototype.PrototypeFact;
import org.kie.api.prototype.PrototypeFactInstance;
import org.kie.api.runtime.KieSession;

import static org.assertj.core.api.Assertions.assertThat;
import static org.drools.model.DSL.on;
import static org.drools.model.PatternDSL.rule;
import static org.drools.model.prototype.PrototypeDSL.protoPattern;
import static org.drools.model.prototype.PrototypeDSL.variable;
import static org.drools.model.prototype.PrototypeExpression.fixedValue;
import static org.drools.model.prototype.PrototypeExpression.prototypeField;
import static org.kie.api.prototype.PrototypeBuilder.prototype;

public class CustomConstraintOperatorTest {

static class CustomConstraintOperator implements ConstraintOperator {

public int counter = 0;

@Override
public <T, V> BiPredicate<T, V> asPredicate() {
return (t, v) -> {
counter++;
return t.equals(v);
};
}

@Override
public boolean hasIndex() {
return true;
}

@Override
public Index.ConstraintType getIndexType() {
return Index.ConstraintType.EQUAL;
}

@Override
public String toString() {
return Index.ConstraintType.EQUAL.toString();
}
}

@Test
public void alphaIndexIneffective() {
CustomConstraintOperator customConstraintOperator = new CustomConstraintOperator();

PrototypeFact testPrototype = prototype("test").asFact();
PrototypeVariable testV = variable(testPrototype);

Rule rule1 = rule("alpha1")
.build(
protoPattern(testV)
.expr(prototypeField("fieldA"), customConstraintOperator, fixedValue(1)),
on(testV).execute((drools, x) ->
drools.insert(new Result("Found"))
)
);

Model model = new ModelImpl().addRule(rule1);
KieBase kieBase = KieBaseBuilder.createKieBaseFromModel(model);
KieSession ksession = kieBase.newKieSession();

PrototypeFactInstance testFact = testPrototype.newInstance();
testFact.put("fieldA", 1);

ksession.insert(testFact);
assertThat(ksession.fireAllRules()).isEqualTo(1);

// Index is created, but actual alpha index hashing works only with more than 3 nodes
Index index = getFirstAlphaNodeIndex((InternalKnowledgeBase) kieBase, testPrototype);
assertThat(index.getIndexType()).isEqualTo(Index.IndexType.ALPHA);
assertThat(index.getConstraintType()).isEqualTo(Index.ConstraintType.EQUAL);

// alpha index hashing is not effective, so the predicated is called
assertThat(customConstraintOperator.counter).isEqualTo(1);
}

@Test
public void alphaIndexEffective() {
CustomConstraintOperator customConstraintOperator = new CustomConstraintOperator();

PrototypeFact testPrototype = prototype("test").asFact();
PrototypeVariable testV = variable(testPrototype);

Rule rule1 = rule("alpha1")
.build(
protoPattern(testV)
.expr(prototypeField("fieldA"), customConstraintOperator, fixedValue(1)),
on(testV).execute((drools, x) ->
drools.insert(new Result("Found"))
)
);
Rule rule2 = rule("alpha2")
.build(
protoPattern(testV)
.expr(prototypeField("fieldA"), customConstraintOperator, fixedValue(2)),
on(testV).execute((drools, x) ->
drools.insert(new Result("Found"))
)
);
Rule rule3 = rule("alpha3")
.build(
protoPattern(testV)
.expr(prototypeField("fieldA"), customConstraintOperator, fixedValue(3)),
on(testV).execute((drools, x) ->
drools.insert(new Result("Found"))
)
);

Model model = new ModelImpl().addRule(rule1).addRule(rule2).addRule(rule3);
KieBase kieBase = KieBaseBuilder.createKieBaseFromModel(model);
KieSession ksession = kieBase.newKieSession();

PrototypeFactInstance testFact = testPrototype.newInstance();
testFact.put("fieldA", 1);

ksession.insert(testFact);
assertThat(ksession.fireAllRules()).isEqualTo(1);

Index index = getFirstAlphaNodeIndex((InternalKnowledgeBase) kieBase, testPrototype);
assertThat(index.getIndexType()).isEqualTo(Index.IndexType.ALPHA);
assertThat(index.getConstraintType()).isEqualTo(Index.ConstraintType.EQUAL);

// alpha index hashing is effective, so the predicated is not called
assertThat(customConstraintOperator.counter).isZero();
}

private static Index getFirstAlphaNodeIndex(InternalKnowledgeBase kieBase, PrototypeFact testPrototype) {
EntryPointNode epn = kieBase.getRete().getEntryPointNodes().values().iterator().next();
ObjectTypeNode otn = epn.getObjectTypeNodes().get(new PrototypeObjectType(testPrototype));
AlphaNode alphaNode = (AlphaNode) otn.getObjectSinkPropagator().getSinks()[0];
IndexableConstraint constraint = (IndexableConstraint) alphaNode.getConstraint();
return ((LambdaConstraint) constraint).getEvaluator().getIndex();
}

@Test
public void betaIndex() {
CustomConstraintOperator customConstraintOperator = new CustomConstraintOperator();

Result result = new Result();

PrototypeFact personFact = prototype("org.drools.Person").withField("name").withField("age").asFact();

PrototypeVariable markV = variable(personFact);
PrototypeVariable ageMateV = variable(personFact);

Rule rule = rule("beta")
.build(
protoPattern(markV)
.expr("name", Index.ConstraintType.EQUAL, "Mark"),
protoPattern(ageMateV)
.expr("name", Index.ConstraintType.NOT_EQUAL, "Mark")
.expr("age", customConstraintOperator, markV, "age"),
on(ageMateV, markV).execute((p1, p2) -> result.setValue(p1.get("name") + " is the same age as " + p2.get("name")))
);

Model model = new ModelImpl().addRule(rule);
KieBase kieBase = KieBaseBuilder.createKieBaseFromModel(model);

KieSession ksession = kieBase.newKieSession();

PrototypeFactInstance mark = personFact.newInstance();
mark.put("name", "Mark");
mark.put("age", 37);

PrototypeFactInstance john = personFact.newInstance();
john.put("name", "John");
john.put("age", 39);

PrototypeFactInstance paul = personFact.newInstance();
paul.put("name", "Paul");
paul.put("age", 37);

ksession.insert(mark);
ksession.insert(john);
ksession.insert(paul);

ksession.fireAllRules();
assertThat(result.getValue()).isEqualTo("Paul is the same age as Mark");

Index index = getFirstBetaNodeIndex((InternalKnowledgeBase) kieBase, personFact);
assertThat(index.getIndexType()).isEqualTo(Index.IndexType.BETA);
assertThat(index.getConstraintType()).isEqualTo(Index.ConstraintType.EQUAL);

// When beta index is used, the predicate in the custom operator is not actually called
assertThat(customConstraintOperator.counter).isZero();
}

private static Index getFirstBetaNodeIndex(InternalKnowledgeBase kieBase, PrototypeFact testPrototype) {
EntryPointNode epn = kieBase.getRete().getEntryPointNodes().values().iterator().next();
ObjectTypeNode otn = epn.getObjectTypeNodes().get(new PrototypeObjectType(testPrototype));
NetworkNode[] sinks = otn.getObjectSinkPropagator().getSinks();
BetaNode betaNode = findBetaNode(sinks);

IndexableConstraint constraint = (IndexableConstraint) betaNode.getConstraints()[0];
return ((LambdaConstraint) constraint).getEvaluator().getIndex();
}

private static BetaNode findBetaNode(NetworkNode[] sinks) {
for (NetworkNode sink : sinks) {
if (sink instanceof BetaNode) {
return (BetaNode) sink;
} else {
BetaNode betaNode = findBetaNode(sink.getSinks());
if (betaNode != null) {
return betaNode;
}
}
}
return null;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -127,16 +127,22 @@ public PrototypePatternDef expr(PrototypeExpression left, ConstraintOperator ope
reactOnFields.addAll(left.getImpactedFields());
reactOnFields.addAll(right.getImpactedFields());

// If operator is not Index.ConstraintType, it may contain Index.ConstraintType internally for indexing purposes
ConstraintOperator operatorAsIndexType = operator;
if (operator.hasIndex()) {
operatorAsIndexType = operator.getIndexType();
}

expr(createExprId(left, operator, right),
asPredicate1(leftExtractor, operator, right.asFunction(prototype)),
createAlphaIndex(left, operator, right, prototype, leftExtractor),
createAlphaIndex(left, operatorAsIndexType, right, prototype, leftExtractor),
reactOn( reactOnFields.toArray(new String[reactOnFields.size()])) );

return this;
}

private static AlphaIndex createAlphaIndex(PrototypeExpression left, ConstraintOperator operator, PrototypeExpression right, Prototype prototype, Function1<PrototypeFactInstance, Object> leftExtractor) {
if (left.getIndexingKey().isPresent() && right instanceof PrototypeExpression.FixedValue && operator instanceof Index.ConstraintType constraintType) {
private static AlphaIndex createAlphaIndex(PrototypeExpression left, ConstraintOperator operatorAsIndexType, PrototypeExpression right, Prototype prototype, Function1<PrototypeFactInstance, Object> leftExtractor) {
if (left.getIndexingKey().isPresent() && right instanceof PrototypeExpression.FixedValue && operatorAsIndexType instanceof Index.ConstraintType constraintType) {
String fieldName = left.getIndexingKey().get();
Prototype.Field field = prototype.getField(fieldName);
Object value = ((PrototypeExpression.FixedValue) right).getValue();
Expand Down Expand Up @@ -167,9 +173,15 @@ public PrototypePatternDef expr(PrototypeExpression left, ConstraintOperator ope
reactOnFields.addAll(left.getImpactedFields());
reactOnFields.addAll(right.getImpactedFields());

// If operator is not Index.ConstraintType, it may contain Index.ConstraintType internally for indexing purposes
ConstraintOperator operatorAsIndexType = operator;
if (operator.hasIndex()) {
operatorAsIndexType = operator.getIndexType();
}

expr(createExprId(left, operator, right),
other, asPredicate2(left.asFunction(prototype), operator, right.asFunction(otherPrototype)),
createBetaIndex(left, operator, right, prototype, otherPrototype),
createBetaIndex(left, operatorAsIndexType, right, prototype, otherPrototype),
reactOn( reactOnFields.toArray(new String[reactOnFields.size()])) );

return this;
Expand All @@ -181,8 +193,8 @@ private static String createExprId(PrototypeExpression left, ConstraintOperator
return "expr:" + leftId + ":" + operator + ":" + rightId;
}

private BetaIndex createBetaIndex(PrototypeExpression left, ConstraintOperator operator, PrototypeExpression right, Prototype prototype, Prototype otherPrototype) {
if (left.getIndexingKey().isPresent() && operator instanceof Index.ConstraintType constraintType && right.getIndexingKey().isPresent()) {
private BetaIndex createBetaIndex(PrototypeExpression left, ConstraintOperator operatorAsIndexType, PrototypeExpression right, Prototype prototype, Prototype otherPrototype) {
if (left.getIndexingKey().isPresent() && operatorAsIndexType instanceof Index.ConstraintType constraintType && right.getIndexingKey().isPresent()) {
String fieldName = left.getIndexingKey().get();
Prototype.Field field = prototype.getField(fieldName);
Function1<PrototypeFactInstance, Object> extractor = left.asFunction(prototype);
Expand Down
Loading