diff --git a/jitexecutor/jitexecutor-dmn/src/main/java/org/kie/kogito/jitexecutor/dmn/DMNEvaluator.java b/jitexecutor/jitexecutor-dmn/src/main/java/org/kie/kogito/jitexecutor/dmn/DMNEvaluator.java index 979f4159bd..92ec93303f 100644 --- a/jitexecutor/jitexecutor-dmn/src/main/java/org/kie/kogito/jitexecutor/dmn/DMNEvaluator.java +++ b/jitexecutor/jitexecutor-dmn/src/main/java/org/kie/kogito/jitexecutor/dmn/DMNEvaluator.java @@ -22,7 +22,9 @@ import java.util.Collection; import java.util.Collections; import java.util.HashMap; +import java.util.List; import java.util.Map; +import java.util.Optional; import org.kie.api.io.Resource; import org.kie.dmn.api.core.DMNContext; @@ -36,6 +38,7 @@ import org.kie.internal.io.ResourceFactory; import org.kie.kogito.jitexecutor.common.requests.MultipleResourcesPayload; import org.kie.kogito.jitexecutor.common.requests.ResourceWithURI; +import org.kie.kogito.jitexecutor.dmn.responses.JITDMNResult; import org.kie.kogito.jitexecutor.dmn.utils.ResolveByKey; public class DMNEvaluator { @@ -47,6 +50,7 @@ public static DMNEvaluator fromXML(String modelXML) { Resource modelResource = ResourceFactory.newReaderResource(new StringReader(modelXML), "UTF-8"); DMNRuntime dmnRuntime = DMNRuntimeBuilder.fromDefaults().buildConfiguration() .fromResources(Collections.singletonList(modelResource)).getOrElseThrow(RuntimeException::new); + dmnRuntime.addListener(new JITDMNListener()); DMNModel dmnModel = dmnRuntime.getModels().get(0); return new DMNEvaluator(dmnModel, dmnRuntime); } @@ -73,9 +77,16 @@ public Collection getAllDMNModels() { return dmnRuntime.getModels(); } - public DMNResult evaluate(Map context) { - DMNContext dmnContext = new DynamicDMNContextBuilder(dmnRuntime.newContext(), dmnModel).populateContextWith(context); - return dmnRuntime.evaluateAll(dmnModel, dmnContext); + public JITDMNResult evaluate(Map context) { + DMNContext dmnContext = + new DynamicDMNContextBuilder(dmnRuntime.newContext(), dmnModel).populateContextWith(context); + DMNResult dmnResult = dmnRuntime.evaluateAll(dmnModel, dmnContext); + Optional> evaluationHitIds = dmnRuntime.getListeners().stream() + .filter(JITDMNListener.class::isInstance) + .findFirst() + .map(JITDMNListener.class::cast) + .map(JITDMNListener::getEvaluationHitIds); + return new JITDMNResult(getNamespace(), getName(), dmnResult, evaluationHitIds.orElse(Collections.emptyList())); } public static DMNEvaluator fromMultiple(MultipleResourcesPayload payload) { diff --git a/jitexecutor/jitexecutor-dmn/src/main/java/org/kie/kogito/jitexecutor/dmn/JITDMNListener.java b/jitexecutor/jitexecutor-dmn/src/main/java/org/kie/kogito/jitexecutor/dmn/JITDMNListener.java new file mode 100644 index 0000000000..9f753e7548 --- /dev/null +++ b/jitexecutor/jitexecutor-dmn/src/main/java/org/kie/kogito/jitexecutor/dmn/JITDMNListener.java @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.kie.kogito.jitexecutor.dmn; + +import java.util.ArrayList; +import java.util.List; + +import org.kie.dmn.api.core.event.AfterConditionalEvaluationEvent; +import org.kie.dmn.api.core.event.AfterEvaluateAllEvent; +import org.kie.dmn.api.core.event.AfterEvaluateBKMEvent; +import org.kie.dmn.api.core.event.AfterEvaluateContextEntryEvent; +import org.kie.dmn.api.core.event.AfterEvaluateDecisionEvent; +import org.kie.dmn.api.core.event.AfterEvaluateDecisionServiceEvent; +import org.kie.dmn.api.core.event.AfterEvaluateDecisionTableEvent; +import org.kie.dmn.api.core.event.AfterInvokeBKMEvent; +import org.kie.dmn.api.core.event.DMNEvent; +import org.kie.dmn.api.core.event.DMNRuntimeEventListener; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class JITDMNListener implements DMNRuntimeEventListener { + + private final List evaluationHitIds = new ArrayList<>(); + + private static final Logger LOGGER = LoggerFactory.getLogger(JITDMNListener.class); + + @Override + public void afterEvaluateDecisionTable(AfterEvaluateDecisionTableEvent event) { + logEvent(event); + evaluationHitIds.addAll(event.getSelectedIds()); + } + + @Override + public void afterEvaluateDecision(AfterEvaluateDecisionEvent event) { + logEvent(event); + } + + @Override + public void afterEvaluateBKM(AfterEvaluateBKMEvent event) { + logEvent(event); + } + + @Override + public void afterEvaluateContextEntry(AfterEvaluateContextEntryEvent event) { + logEvent(event); + } + + @Override + public void afterEvaluateDecisionService(AfterEvaluateDecisionServiceEvent event) { + logEvent(event); + } + + @Override + public void afterInvokeBKM(AfterInvokeBKMEvent event) { + logEvent(event); + } + + @Override + public void afterEvaluateAll(AfterEvaluateAllEvent event) { + logEvent(event); + } + + @Override + public void afterConditionalEvaluation(AfterConditionalEvaluationEvent event) { + logEvent(event); + evaluationHitIds.add(event.getExecutedId()); + } + + public List getEvaluationHitIds() { + return evaluationHitIds; + } + + private void logEvent(DMNEvent toLog) { + LOGGER.info("{} event {}", toLog.getClass().getSimpleName(), toLog); + } + + private void logEvent(AfterConditionalEvaluationEvent toLog) { + LOGGER.info("{} event {}", toLog.getClass().getSimpleName(), toLog); + } + +} diff --git a/jitexecutor/jitexecutor-dmn/src/main/java/org/kie/kogito/jitexecutor/dmn/JITDMNServiceImpl.java b/jitexecutor/jitexecutor-dmn/src/main/java/org/kie/kogito/jitexecutor/dmn/JITDMNServiceImpl.java index 320f1a9609..bb0b6513c8 100644 --- a/jitexecutor/jitexecutor-dmn/src/main/java/org/kie/kogito/jitexecutor/dmn/JITDMNServiceImpl.java +++ b/jitexecutor/jitexecutor-dmn/src/main/java/org/kie/kogito/jitexecutor/dmn/JITDMNServiceImpl.java @@ -76,8 +76,7 @@ public JITDMNServiceImpl(int explainabilityLimeSampleSize, int explainabilityLim @Override public JITDMNResult evaluateModel(String modelXML, Map context) { DMNEvaluator dmnEvaluator = DMNEvaluator.fromXML(modelXML); - DMNResult dmnResult = dmnEvaluator.evaluate(context); - return new JITDMNResult(dmnEvaluator.getNamespace(), dmnEvaluator.getName(), dmnResult); + return dmnEvaluator.evaluate(context); } @Override diff --git a/jitexecutor/jitexecutor-dmn/src/main/java/org/kie/kogito/jitexecutor/dmn/responses/JITDMNResult.java b/jitexecutor/jitexecutor-dmn/src/main/java/org/kie/kogito/jitexecutor/dmn/responses/JITDMNResult.java index bfaea16420..4f43a123f6 100644 --- a/jitexecutor/jitexecutor-dmn/src/main/java/org/kie/kogito/jitexecutor/dmn/responses/JITDMNResult.java +++ b/jitexecutor/jitexecutor-dmn/src/main/java/org/kie/kogito/jitexecutor/dmn/responses/JITDMNResult.java @@ -21,6 +21,7 @@ import java.io.Serializable; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -49,16 +50,23 @@ public class JITDMNResult implements Serializable, private Map decisionResults = new HashMap<>(); + private List evaluationHitIds; + public JITDMNResult() { // Intentionally blank. } public JITDMNResult(String namespace, String modelName, org.kie.dmn.api.core.DMNResult dmnResult) { + this(namespace, modelName, dmnResult, Collections.emptyList()); + } + + public JITDMNResult(String namespace, String modelName, org.kie.dmn.api.core.DMNResult dmnResult, List evaluationHitIds) { this.namespace = namespace; this.modelName = modelName; this.setDmnContext(dmnResult.getContext().getAll()); this.setMessages(dmnResult.getMessages()); this.setDecisionResults(dmnResult.getDecisionResults()); + this.evaluationHitIds = evaluationHitIds; } public String getNamespace() { @@ -102,6 +110,14 @@ public void setDecisionResults(List decisionResults } } + public List getEvaluationHitIds() { + return evaluationHitIds; + } + + public void setEvaluationHitIds(List evaluationHitIds) { + this.evaluationHitIds = evaluationHitIds; + } + @JsonIgnore @Override public DMNContext getContext() { @@ -151,6 +167,7 @@ public String toString() { .append(", dmnContext=").append(dmnContext) .append(", messages=").append(messages) .append(", decisionResults=").append(decisionResults) + .append(", evaluationHitIds=").append(evaluationHitIds) .append("]").toString(); } } diff --git a/jitexecutor/jitexecutor-dmn/src/test/java/org/kie/kogito/jitexecutor/dmn/JITDMNServiceImplTest.java b/jitexecutor/jitexecutor-dmn/src/test/java/org/kie/kogito/jitexecutor/dmn/JITDMNServiceImplTest.java index 3169dce379..e99957ef0c 100644 --- a/jitexecutor/jitexecutor-dmn/src/test/java/org/kie/kogito/jitexecutor/dmn/JITDMNServiceImplTest.java +++ b/jitexecutor/jitexecutor-dmn/src/test/java/org/kie/kogito/jitexecutor/dmn/JITDMNServiceImplTest.java @@ -19,8 +19,11 @@ package org.kie.kogito.jitexecutor.dmn; import java.io.IOException; +import java.math.BigDecimal; +import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; +import java.util.List; import java.util.Map; import org.junit.jupiter.api.Assertions; @@ -43,7 +46,7 @@ public static void setup() throws IOException { } @Test - public void testModelEvaluation() { + void testModelEvaluation() { Map context = new HashMap<>(); context.put("FICO Score", 800); context.put("DTI Ratio", .1); @@ -57,7 +60,141 @@ public void testModelEvaluation() { } @Test - public void testExplainability() throws IOException { + void testDecisionTableModelEvaluation() throws IOException { + String decisionTableModel = getModelFromIoUtils("valid_models/DMNv1_x/LoanEligibility.dmn"); + Map client = new HashMap<>(); + client.put("age", 43); + client.put("salary", 1950); + client.put("existing payments", 100); + + Map loan = new HashMap<>(); + loan.put("duration", 15); + loan.put("installment", 180); + Map context = new HashMap<>(); + + context.put("Client", client); + context.put("Loan", loan); + context.put("SupremeDirector", "No"); + context.put("Bribe", 10); + JITDMNResult dmnResult = jitdmnService.evaluateModel(decisionTableModel, context); + + Assertions.assertEquals("LoanEligibility", dmnResult.getModelName()); + Assertions.assertEquals("https://github.com/kiegroup/kogito-examples/dmn-quarkus-listener-example", dmnResult.getNamespace()); + Assertions.assertTrue(dmnResult.getMessages().isEmpty()); + Assertions.assertEquals("Yes", dmnResult.getDecisionResultByName("Eligibility").getResult()); + } + + @Test + void testEvaluationHitIds() throws IOException { + final String thenElementId = "_6481FF12-61B5-451C-B775-4143D9B6CD6B"; + final String elseElementId = "_2CD02CB2-6B56-45C4-B461-405E89D45633"; + final String ruleId0 = "_1578BD9E-2BF9-4BFC-8956-1A736959C937"; + final String ruleId1 = "_31CD7AA3-A806-4E7E-B512-821F82043620"; + final String ruleId3 = "_2545E1A8-93D3-4C8A-A0ED-8AD8B10A58F9"; + final String ruleId4 = "_510A50DA-D5A4-4F06-B0BE-7F8F2AA83740"; + String decisionTableModel = getModelFromIoUtils("valid_models/DMNv1_5/RiskScore_Simple.dmn"); + Map context = new HashMap<>(); + context.put("Credit Score", "Poor"); + context.put("DTI", 33); + JITDMNResult dmnResult = jitdmnService.evaluateModel(decisionTableModel, context); + + Assertions.assertEquals("DMN_A77074C1-21FE-4F7E-9753-F84661569AFC", dmnResult.getModelName()); + Assertions.assertTrue(dmnResult.getMessages().isEmpty()); + Assertions.assertEquals(BigDecimal.valueOf(50), dmnResult.getDecisionResultByName("Risk Score").getResult()); + List evaluationHitIds = dmnResult.getEvaluationHitIds(); + Assertions.assertNotNull(evaluationHitIds); + Assertions.assertEquals(3, evaluationHitIds.size()); + Assertions.assertTrue(evaluationHitIds.contains(elseElementId)); + Assertions.assertTrue(evaluationHitIds.contains(ruleId0)); + Assertions.assertTrue(evaluationHitIds.contains(ruleId3)); + + context = new HashMap<>(); + context.put("Credit Score", "Excellent"); + context.put("DTI", 10); + dmnResult = jitdmnService.evaluateModel(decisionTableModel, context); + + Assertions.assertTrue(dmnResult.getMessages().isEmpty()); + Assertions.assertEquals(BigDecimal.valueOf(20), dmnResult.getDecisionResultByName("Risk Score").getResult()); + evaluationHitIds = dmnResult.getEvaluationHitIds(); + Assertions.assertNotNull(evaluationHitIds); + Assertions.assertEquals(3, evaluationHitIds.size()); + Assertions.assertTrue(evaluationHitIds.contains(thenElementId)); + Assertions.assertTrue(evaluationHitIds.contains(ruleId1)); + Assertions.assertTrue(evaluationHitIds.contains(ruleId4)); + } + + @Test + void testConditionalWithNestedDecisionTableFromRiskScoreEvaluation() throws IOException { + final String thenElementId = "_6481FF12-61B5-451C-B775-4143D9B6CD6B"; + final String thenRuleId0 = "_D1753442-03F0-414B-94F8-6A86182DF6EB"; + final String thenRuleId4 = "_E787BA51-E31D-449B-A432-50BE7466A15E"; + final String elseElementId = "_2CD02CB2-6B56-45C4-B461-405E89D45633"; + final String elseRuleId2 = "_945A5471-9F91-4751-9D96-74978F6FB12B"; + final String elseRuleId5 = "_654BBFBC-9B84-4BD8-9D0B-13E8DD1B9F5D"; + String decisionTableModel = getModelFromIoUtils("valid_models/DMNv1_5/RiskScore_Conditional.dmn"); + + Map context = new HashMap<>(); + context.put("Credit Score", "Poor"); + context.put("DTI", 33); + context.put("World Region", "Asia"); + JITDMNResult dmnResult = jitdmnService.evaluateModel(decisionTableModel, context); + + Assertions.assertTrue(dmnResult.getMessages().isEmpty()); + Assertions.assertEquals(BigDecimal.valueOf(50), dmnResult.getDecisionResultByName("Risk Score").getResult()); + List evaluationHitIds = dmnResult.getEvaluationHitIds(); + Assertions.assertNotNull(evaluationHitIds); + Assertions.assertEquals(3, evaluationHitIds.size()); + Assertions.assertTrue(evaluationHitIds.contains(thenElementId)); + Assertions.assertTrue(evaluationHitIds.contains(thenRuleId0)); + Assertions.assertTrue(evaluationHitIds.contains(thenRuleId4)); + + context = new HashMap<>(); + context.put("Credit Score", "Excellent"); + context.put("DTI", 10); + context.put("World Region", "Europe"); + dmnResult = jitdmnService.evaluateModel(decisionTableModel, context); + + Assertions.assertTrue(dmnResult.getMessages().isEmpty()); + Assertions.assertEquals(BigDecimal.valueOf(30), dmnResult.getDecisionResultByName("Risk Score").getResult()); + evaluationHitIds = dmnResult.getEvaluationHitIds(); + Assertions.assertNotNull(evaluationHitIds); + Assertions.assertEquals(3, evaluationHitIds.size()); + Assertions.assertTrue(evaluationHitIds.contains(elseElementId)); + Assertions.assertTrue(evaluationHitIds.contains(elseRuleId2)); + Assertions.assertTrue(evaluationHitIds.contains(elseRuleId5)); + } + + @Test + void testMultipleHitRulesEvaluation() throws IOException { + final String rule0 = "_E5C380DA-AF7B-4401-9804-C58296EC09DD"; + final String rule1 = "_DFD65E8B-5648-4BFD-840F-8C76B8DDBD1A"; + final String rule2 = "_E80EE7F7-1C0C-4050-B560-F33611F16B05"; + String decisionTableModel = getModelFromIoUtils("valid_models/DMNv1_5/MultipleHitRules.dmn"); + + final List numbers = new ArrayList<>(); + numbers.add(BigDecimal.valueOf(10)); + numbers.add(BigDecimal.valueOf(2)); + numbers.add(BigDecimal.valueOf(1)); + final Map context = new HashMap<>(); + context.put("Numbers", numbers); + final JITDMNResult dmnResult = jitdmnService.evaluateModel(decisionTableModel, context); + + final List expectedStatistcs = new ArrayList<>(); + expectedStatistcs.add(BigDecimal.valueOf(6)); + expectedStatistcs.add(BigDecimal.valueOf(3)); + expectedStatistcs.add(BigDecimal.valueOf(1)); + Assertions.assertTrue(dmnResult.getMessages().isEmpty()); + Assertions.assertEquals(expectedStatistcs, dmnResult.getDecisionResultByName("Statistics").getResult()); + final List evaluationHitIds = dmnResult.getEvaluationHitIds(); + Assertions.assertNotNull(evaluationHitIds); + Assertions.assertEquals(6, evaluationHitIds.size()); + Assertions.assertEquals(3, evaluationHitIds.stream().filter(rule0::equals).count()); + Assertions.assertEquals(2, evaluationHitIds.stream().filter(rule1::equals).count()); + Assertions.assertEquals(1, evaluationHitIds.stream().filter(rule2::equals).count()); + } + + @Test + void testExplainability() throws IOException { String allTypesModel = getModelFromIoUtils("valid_models/DMNv1_x/allTypes.dmn"); Map context = new HashMap<>(); diff --git a/jitexecutor/jitexecutor-dmn/src/test/java/org/kie/kogito/jitexecutor/dmn/api/JITDMNResourceTest.java b/jitexecutor/jitexecutor-dmn/src/test/java/org/kie/kogito/jitexecutor/dmn/api/JITDMNResourceTest.java index d728d112c3..6c29603088 100644 --- a/jitexecutor/jitexecutor-dmn/src/test/java/org/kie/kogito/jitexecutor/dmn/api/JITDMNResourceTest.java +++ b/jitexecutor/jitexecutor-dmn/src/test/java/org/kie/kogito/jitexecutor/dmn/api/JITDMNResourceTest.java @@ -22,10 +22,17 @@ import java.util.HashMap; import java.util.Map; +import org.assertj.core.api.Assertions; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import org.kie.kogito.jitexecutor.dmn.requests.JITDMNPayload; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.DeserializationFeature; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.node.ArrayNode; + import io.quarkus.test.junit.QuarkusTest; import io.restassured.http.ContentType; @@ -38,15 +45,25 @@ public class JITDMNResourceTest { private static String model; private static String modelWithExtensionElements; + private static String modelWithEvaluationHitIds; + + private static final ObjectMapper MAPPER = new ObjectMapper(); + + private static final String EVALUATION_HIT_IDS_FIELD_NAME = "evaluationHitIds"; + + static { + MAPPER.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false); + } @BeforeAll public static void setup() throws IOException { model = getModelFromIoUtils("invalid_models/DMNv1_x/test.dmn"); modelWithExtensionElements = getModelFromIoUtils("valid_models/DMNv1_x/testWithExtensionElements.dmn"); + modelWithEvaluationHitIds = getModelFromIoUtils("valid_models/DMNv1_5/RiskScore_Simple.dmn"); } @Test - public void testjitEndpoint() { + void testjitEndpoint() { JITDMNPayload jitdmnpayload = new JITDMNPayload(model, buildContext()); given() .contentType(ContentType.JSON) @@ -58,7 +75,7 @@ public void testjitEndpoint() { } @Test - public void testjitdmnResultEndpoint() { + void testjitdmnResultEndpoint() { JITDMNPayload jitdmnpayload = new JITDMNPayload(model, buildContext()); given() .contentType(ContentType.JSON) @@ -70,7 +87,34 @@ public void testjitdmnResultEndpoint() { } @Test - public void testjitExplainabilityEndpoint() { + void testjitdmnResultEndpointWithEvaluationHitIds() throws JsonProcessingException { + JITDMNPayload jitdmnpayload = new JITDMNPayload(modelWithEvaluationHitIds, buildRiskScoreContext()); + final String elseElementId = "_2CD02CB2-6B56-45C4-B461-405E89D45633"; + final String ruleId0 = "_1578BD9E-2BF9-4BFC-8956-1A736959C937"; + final String ruleId3 = "_2545E1A8-93D3-4C8A-A0ED-8AD8B10A58F9"; + String response = given().contentType(ContentType.JSON) + .body(jitdmnpayload) + .when().post("/jitdmn/dmnresult") + .then() + .statusCode(200) + .body(containsString("Risk Score"), + containsString("Loan Pre-Qualification"), + containsString(EVALUATION_HIT_IDS_FIELD_NAME), + containsString(elseElementId), + containsString(ruleId0), + containsString(ruleId3)) + .extract() + .asString(); + JsonNode retrieved = MAPPER.readTree(response); + ArrayNode evaluationHitIdsNode = (ArrayNode) retrieved.get(EVALUATION_HIT_IDS_FIELD_NAME); + Assertions.assertThat(evaluationHitIdsNode).hasSize(3) + .anyMatch(node -> node.asText().equals(elseElementId)) + .anyMatch(node -> node.asText().equals(ruleId0)) + .anyMatch(node -> node.asText().equals(ruleId3)); + } + + @Test + void testjitExplainabilityEndpoint() { JITDMNPayload jitdmnpayload = new JITDMNPayload(model, buildContext()); given() .contentType(ContentType.JSON) @@ -78,11 +122,12 @@ public void testjitExplainabilityEndpoint() { .when().post("/jitdmn/evaluateAndExplain") .then() .statusCode(200) - .body(containsString("dmnResult"), containsString("saliencies"), containsString("xls2dmn"), containsString("featureName")); + .body(containsString("dmnResult"), containsString("saliencies"), containsString("xls2dmn"), + containsString("featureName")); } @Test - public void testjitdmnWithExtensionElements() { + void testjitdmnWithExtensionElements() { Map context = new HashMap<>(); context.put("m", 1); context.put("n", 2); @@ -97,6 +142,13 @@ public void testjitdmnWithExtensionElements() { .body(containsString("m"), containsString("n"), containsString("sum")); } + private Map buildRiskScoreContext() { + Map context = new HashMap<>(); + context.put("Credit Score", "Poor"); + context.put("DTI", 33); + return context; + } + private Map buildContext() { Map context = new HashMap<>(); context.put("FICO Score", 800);