Skip to content

Commit

Permalink
[incubator-kie-issues#1591] Aggregate evaluationHitIds to Map<String,…
Browse files Browse the repository at this point in the history
… Integer> (apache#2134)

* [incubator-kie-issues#1591] Aggregate evaluationHitIds to Map<String, Integer>

* [incubator-kie-issues#1591] Fixed as per PR suggestion

---------

Co-authored-by: Gabriele-Cardosi <gabriele.cardosi@ibm.com>
  • Loading branch information
2 people authored and rgdoliveira committed Nov 7, 2024
1 parent 1e9edd2 commit 262af48
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 64 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;

Expand Down Expand Up @@ -81,12 +80,12 @@ public JITDMNResult evaluate(Map<String, Object> context) {
DMNContext dmnContext =
new DynamicDMNContextBuilder(dmnRuntime.newContext(), dmnModel).populateContextWith(context);
DMNResult dmnResult = dmnRuntime.evaluateAll(dmnModel, dmnContext);
Optional<List<String>> evaluationHitIds = dmnRuntime.getListeners().stream()
Optional<Map<String, Integer>> evaluationHitIds = dmnRuntime.getListeners().stream()
.filter(JITDMNListener.class::isInstance)
.findFirst()
.map(JITDMNListener.class::cast)
.map(JITDMNListener::getEvaluationHitIds);
return new JITDMNResult(getNamespace(), getName(), dmnResult, evaluationHitIds.orElse(Collections.emptyList()));
return new JITDMNResult(getNamespace(), getName(), dmnResult, evaluationHitIds.orElse(Collections.emptyMap()));
}

public static DMNEvaluator fromMultiple(MultipleResourcesPayload payload) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
*/
package org.kie.kogito.jitexecutor.dmn;

import java.util.ArrayList;
import java.util.List;
import java.util.HashMap;
import java.util.Map;

import org.kie.dmn.api.core.event.AfterConditionalEvaluationEvent;
import org.kie.dmn.api.core.event.AfterEvaluateAllEvent;
Expand All @@ -36,14 +36,14 @@

public class JITDMNListener implements DMNRuntimeEventListener {

private final List<String> evaluationHitIds = new ArrayList<>();
private final Map<String, Integer> evaluationHitIds = new HashMap<>();

private static final Logger LOGGER = LoggerFactory.getLogger(JITDMNListener.class);

@Override
public void afterEvaluateDecisionTable(AfterEvaluateDecisionTableEvent event) {
logEvent(event);
evaluationHitIds.addAll(event.getSelectedIds());
event.getSelectedIds().forEach(s -> evaluationHitIds.compute(s, (k, v) -> v == null ? 1 : v + 1));
}

@Override
Expand Down Expand Up @@ -79,10 +79,10 @@ public void afterEvaluateAll(AfterEvaluateAllEvent event) {
@Override
public void afterConditionalEvaluation(AfterConditionalEvaluationEvent event) {
logEvent(event);
evaluationHitIds.add(event.getExecutedId());
evaluationHitIds.compute(event.getExecutedId(), (k, v) -> v == null ? 1 : v + 1);
}

public List<String> getEvaluationHitIds() {
public Map<String, Integer> getEvaluationHitIds() {
return evaluationHitIds;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,17 +50,17 @@ public class JITDMNResult implements Serializable,

private Map<String, JITDMNDecisionResult> decisionResults = new HashMap<>();

private List<String> evaluationHitIds;
private Map<String, Integer> evaluationHitIds;

public JITDMNResult() {
// Intentionally blank.
}

public JITDMNResult(String namespace, String modelName, org.kie.dmn.api.core.DMNResult dmnResult) {
this(namespace, modelName, dmnResult, Collections.emptyList());
this(namespace, modelName, dmnResult, Collections.emptyMap());
}

public JITDMNResult(String namespace, String modelName, org.kie.dmn.api.core.DMNResult dmnResult, List<String> evaluationHitIds) {
public JITDMNResult(String namespace, String modelName, org.kie.dmn.api.core.DMNResult dmnResult, Map<String, Integer> evaluationHitIds) {
this.namespace = namespace;
this.modelName = modelName;
this.setDmnContext(dmnResult.getContext().getAll());
Expand Down Expand Up @@ -110,11 +110,11 @@ public void setDecisionResults(List<? extends DMNDecisionResult> decisionResults
}
}

public List<String> getEvaluationHitIds() {
public Map<String, Integer> getEvaluationHitIds() {
return evaluationHitIds;
}

public void setEvaluationHitIds(List<String> evaluationHitIds) {
public void setEvaluationHitIds(Map<String, Integer> evaluationHitIds) {
this.evaluationHitIds = evaluationHitIds;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@ void testDecisionTableModelEvaluation() throws IOException {
JITDMNResult dmnResult = jitdmnService.evaluateModel(decisionTableModel, context);

Assertions.assertEquals("LoanEligibility", dmnResult.getModelName());
Assertions.assertEquals("https://github.com/kiegroup/kogito-examples/dmn-quarkus-listener-example", dmnResult.getNamespace());
Assertions.assertEquals("https://github.com/kiegroup/kogito-examples/dmn-quarkus-listener-example",
dmnResult.getNamespace());
Assertions.assertTrue(dmnResult.getMessages().isEmpty());
Assertions.assertEquals("Yes", dmnResult.getDecisionResultByName("Eligibility").getResult());
}
Expand All @@ -101,12 +102,12 @@ void testEvaluationHitIds() throws IOException {
Assertions.assertEquals("DMN_A77074C1-21FE-4F7E-9753-F84661569AFC", dmnResult.getModelName());
Assertions.assertTrue(dmnResult.getMessages().isEmpty());
Assertions.assertEquals(BigDecimal.valueOf(50), dmnResult.getDecisionResultByName("Risk Score").getResult());
List<String> evaluationHitIds = dmnResult.getEvaluationHitIds();
Map<String, Integer> evaluationHitIds = dmnResult.getEvaluationHitIds();
Assertions.assertNotNull(evaluationHitIds);
Assertions.assertEquals(3, evaluationHitIds.size());
Assertions.assertTrue(evaluationHitIds.contains(elseElementId));
Assertions.assertTrue(evaluationHitIds.contains(ruleId0));
Assertions.assertTrue(evaluationHitIds.contains(ruleId3));
Assertions.assertTrue(evaluationHitIds.containsKey(elseElementId));
Assertions.assertTrue(evaluationHitIds.containsKey(ruleId0));
Assertions.assertTrue(evaluationHitIds.containsKey(ruleId3));

context = new HashMap<>();
context.put("Credit Score", "Excellent");
Expand All @@ -118,9 +119,9 @@ void testEvaluationHitIds() throws IOException {
evaluationHitIds = dmnResult.getEvaluationHitIds();
Assertions.assertNotNull(evaluationHitIds);
Assertions.assertEquals(3, evaluationHitIds.size());
Assertions.assertTrue(evaluationHitIds.contains(thenElementId));
Assertions.assertTrue(evaluationHitIds.contains(ruleId1));
Assertions.assertTrue(evaluationHitIds.contains(ruleId4));
Assertions.assertTrue(evaluationHitIds.containsKey(thenElementId));
Assertions.assertTrue(evaluationHitIds.containsKey(ruleId1));
Assertions.assertTrue(evaluationHitIds.containsKey(ruleId4));
}

@Test
Expand All @@ -141,12 +142,12 @@ void testConditionalWithNestedDecisionTableFromRiskScoreEvaluation() throws IOEx

Assertions.assertTrue(dmnResult.getMessages().isEmpty());
Assertions.assertEquals(BigDecimal.valueOf(50), dmnResult.getDecisionResultByName("Risk Score").getResult());
List<String> evaluationHitIds = dmnResult.getEvaluationHitIds();
Map<String, Integer> evaluationHitIds = dmnResult.getEvaluationHitIds();
Assertions.assertNotNull(evaluationHitIds);
Assertions.assertEquals(3, evaluationHitIds.size());
Assertions.assertTrue(evaluationHitIds.contains(thenElementId));
Assertions.assertTrue(evaluationHitIds.contains(thenRuleId0));
Assertions.assertTrue(evaluationHitIds.contains(thenRuleId4));
Assertions.assertTrue(evaluationHitIds.containsKey(thenElementId));
Assertions.assertTrue(evaluationHitIds.containsKey(thenRuleId0));
Assertions.assertTrue(evaluationHitIds.containsKey(thenRuleId4));

context = new HashMap<>();
context.put("Credit Score", "Excellent");
Expand All @@ -159,9 +160,9 @@ void testConditionalWithNestedDecisionTableFromRiskScoreEvaluation() throws IOEx
evaluationHitIds = dmnResult.getEvaluationHitIds();
Assertions.assertNotNull(evaluationHitIds);
Assertions.assertEquals(3, evaluationHitIds.size());
Assertions.assertTrue(evaluationHitIds.contains(elseElementId));
Assertions.assertTrue(evaluationHitIds.contains(elseRuleId2));
Assertions.assertTrue(evaluationHitIds.contains(elseRuleId5));
Assertions.assertTrue(evaluationHitIds.containsKey(elseElementId));
Assertions.assertTrue(evaluationHitIds.containsKey(elseRuleId2));
Assertions.assertTrue(evaluationHitIds.containsKey(elseRuleId5));
}

@Test
Expand All @@ -185,12 +186,12 @@ void testMultipleHitRulesEvaluation() throws IOException {
expectedStatistcs.add(BigDecimal.valueOf(1));
Assertions.assertTrue(dmnResult.getMessages().isEmpty());
Assertions.assertEquals(expectedStatistcs, dmnResult.getDecisionResultByName("Statistics").getResult());
final List<String> evaluationHitIds = dmnResult.getEvaluationHitIds();
final Map<String, Integer> evaluationHitIds = dmnResult.getEvaluationHitIds();
Assertions.assertNotNull(evaluationHitIds);
Assertions.assertEquals(6, evaluationHitIds.size());
Assertions.assertEquals(3, evaluationHitIds.stream().filter(rule0::equals).count());
Assertions.assertEquals(2, evaluationHitIds.stream().filter(rule1::equals).count());
Assertions.assertEquals(1, evaluationHitIds.stream().filter(rule2::equals).count());
Assertions.assertEquals(3, evaluationHitIds.size());
Assertions.assertEquals(3, evaluationHitIds.get(rule0));
Assertions.assertEquals(2, evaluationHitIds.get(rule1));
Assertions.assertEquals(1, evaluationHitIds.get(rule2));
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@
package org.kie.kogito.jitexecutor.dmn.api;

import java.io.IOException;
import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.assertj.core.api.Assertions;
Expand All @@ -31,7 +34,7 @@
import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ArrayNode;
import com.fasterxml.jackson.databind.node.ObjectNode;

import io.quarkus.test.junit.QuarkusTest;
import io.restassured.http.ContentType;
Expand All @@ -43,9 +46,9 @@
@QuarkusTest
public class JITDMNResourceTest {

private static String model;
private static String invalidModel;
private static String modelWithExtensionElements;
private static String modelWithEvaluationHitIds;
private static String modelWithMultipleEvaluationHitIds;

private static final ObjectMapper MAPPER = new ObjectMapper();

Expand All @@ -57,14 +60,14 @@ public class JITDMNResourceTest {

@BeforeAll
public static void setup() throws IOException {
model = getModelFromIoUtils("invalid_models/DMNv1_x/test.dmn");
invalidModel = getModelFromIoUtils("invalid_models/DMNv1_x/test.dmn");
modelWithExtensionElements = getModelFromIoUtils("valid_models/DMNv1_x/testWithExtensionElements.dmn");
modelWithEvaluationHitIds = getModelFromIoUtils("valid_models/DMNv1_5/RiskScore_Simple.dmn");
modelWithMultipleEvaluationHitIds = getModelFromIoUtils("valid_models/DMNv1_5/MultipleHitRules.dmn");
}

@Test
void testjitEndpoint() {
JITDMNPayload jitdmnpayload = new JITDMNPayload(model, buildContext());
JITDMNPayload jitdmnpayload = new JITDMNPayload(invalidModel, buildContext());
given()
.contentType(ContentType.JSON)
.body(jitdmnpayload)
Expand All @@ -76,46 +79,44 @@ void testjitEndpoint() {

@Test
void testjitdmnResultEndpoint() {
JITDMNPayload jitdmnpayload = new JITDMNPayload(model, buildContext());
JITDMNPayload jitdmnpayload = new JITDMNPayload(modelWithMultipleEvaluationHitIds, buildMultipleHitContext());
given()
.contentType(ContentType.JSON)
.body(jitdmnpayload)
.when().post("/jitdmn/dmnresult")
.then()
.statusCode(200)
.body(containsString("Loan Approval"), containsString("Approved"), containsString("xls2dmn"));
.body(containsString("Statistics"));
}

@Test
void testjitdmnResultEndpointWithEvaluationHitIds() throws JsonProcessingException {
JITDMNPayload jitdmnpayload = new JITDMNPayload(modelWithEvaluationHitIds, buildRiskScoreContext());
final String elseElementId = "_2CD02CB2-6B56-45C4-B461-405E89D45633";
final String ruleId0 = "_1578BD9E-2BF9-4BFC-8956-1A736959C937";
final String ruleId3 = "_2545E1A8-93D3-4C8A-A0ED-8AD8B10A58F9";
JITDMNPayload jitdmnpayload = new JITDMNPayload(modelWithMultipleEvaluationHitIds, buildMultipleHitContext());
final String rule0 = "_E5C380DA-AF7B-4401-9804-C58296EC09DD";
final String rule1 = "_DFD65E8B-5648-4BFD-840F-8C76B8DDBD1A";
final String rule2 = "_E80EE7F7-1C0C-4050-B560-F33611F16B05";
String response = given().contentType(ContentType.JSON)
.body(jitdmnpayload)
.when().post("/jitdmn/dmnresult")
.then()
.statusCode(200)
.body(containsString("Risk Score"),
containsString("Loan Pre-Qualification"),
.body(containsString("Statistics"),
containsString(EVALUATION_HIT_IDS_FIELD_NAME),
containsString(elseElementId),
containsString(ruleId0),
containsString(ruleId3))
containsString(rule0),
containsString(rule1),
containsString(rule2))
.extract()
.asString();
JsonNode retrieved = MAPPER.readTree(response);
ArrayNode evaluationHitIdsNode = (ArrayNode) retrieved.get(EVALUATION_HIT_IDS_FIELD_NAME);
Assertions.assertThat(evaluationHitIdsNode).hasSize(3)
.anyMatch(node -> node.asText().equals(elseElementId))
.anyMatch(node -> node.asText().equals(ruleId0))
.anyMatch(node -> node.asText().equals(ruleId3));
ObjectNode evaluationHitIdsNode = (ObjectNode) retrieved.get(EVALUATION_HIT_IDS_FIELD_NAME);
Assertions.assertThat(evaluationHitIdsNode).hasSize(3);
final Map<String, Integer> expectedEvaluationHitIds = Map.of(rule0, 3, rule1, 2, rule2, 1);
evaluationHitIdsNode.fields().forEachRemaining(entry -> Assertions.assertThat(expectedEvaluationHitIds).containsEntry(entry.getKey(), entry.getValue().asInt()));
}

@Test
void testjitExplainabilityEndpoint() {
JITDMNPayload jitdmnpayload = new JITDMNPayload(model, buildContext());
JITDMNPayload jitdmnpayload = new JITDMNPayload(invalidModel, buildContext());
given()
.contentType(ContentType.JSON)
.body(jitdmnpayload)
Expand All @@ -142,18 +143,21 @@ void testjitdmnWithExtensionElements() {
.body(containsString("m"), containsString("n"), containsString("sum"));
}

private Map<String, Object> buildRiskScoreContext() {
Map<String, Object> context = new HashMap<>();
context.put("Credit Score", "Poor");
context.put("DTI", 33);
return context;
}

private Map<String, Object> buildContext() {
Map<String, Object> context = new HashMap<>();
context.put("FICO Score", 800);
context.put("DTI Ratio", .1);
context.put("PITI Ratio", .1);
return context;
}

private Map<String, Object> buildMultipleHitContext() {
final List<BigDecimal> numbers = new ArrayList<>();
numbers.add(BigDecimal.valueOf(10));
numbers.add(BigDecimal.valueOf(2));
numbers.add(BigDecimal.valueOf(1));
final Map<String, Object> context = new HashMap<>();
context.put("Numbers", numbers);
return context;
}
}

0 comments on commit 262af48

Please sign in to comment.