Skip to content

Commit

Permalink
[incubator-kie-issues#1591] Aggregate evaluationHitIds to Map<String,…
Browse files Browse the repository at this point in the history
… Integer>
  • Loading branch information
Gabriele-Cardosi committed Oct 30, 2024
1 parent 924e9f7 commit 4a6928e
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 46 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;

Expand Down Expand Up @@ -81,12 +80,12 @@ public JITDMNResult evaluate(Map<String, Object> context) {
DMNContext dmnContext =
new DynamicDMNContextBuilder(dmnRuntime.newContext(), dmnModel).populateContextWith(context);
DMNResult dmnResult = dmnRuntime.evaluateAll(dmnModel, dmnContext);
Optional<List<String>> evaluationHitIds = dmnRuntime.getListeners().stream()
Optional<Map<String, Integer>> evaluationHitIds = dmnRuntime.getListeners().stream()
.filter(JITDMNListener.class::isInstance)
.findFirst()
.map(JITDMNListener.class::cast)
.map(JITDMNListener::getEvaluationHitIds);
return new JITDMNResult(getNamespace(), getName(), dmnResult, evaluationHitIds.orElse(Collections.emptyList()));
return new JITDMNResult(getNamespace(), getName(), dmnResult, evaluationHitIds.orElse(Collections.emptyMap()));
}

public static DMNEvaluator fromMultiple(MultipleResourcesPayload payload) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
*/
package org.kie.kogito.jitexecutor.dmn;

import java.util.ArrayList;
import java.util.List;
import java.util.HashMap;
import java.util.Map;

import org.kie.dmn.api.core.event.AfterConditionalEvaluationEvent;
import org.kie.dmn.api.core.event.AfterEvaluateAllEvent;
Expand All @@ -36,14 +36,14 @@

public class JITDMNListener implements DMNRuntimeEventListener {

private final List<String> evaluationHitIds = new ArrayList<>();
private final Map<String, Integer> evaluationHitIds = new HashMap<>();

private static final Logger LOGGER = LoggerFactory.getLogger(JITDMNListener.class);

@Override
public void afterEvaluateDecisionTable(AfterEvaluateDecisionTableEvent event) {
logEvent(event);
evaluationHitIds.addAll(event.getSelectedIds());
event.getSelectedIds().forEach(s -> evaluationHitIds.compute(s, (k, v) -> v == null ? 1 : v + 1));
}

@Override
Expand Down Expand Up @@ -79,10 +79,10 @@ public void afterEvaluateAll(AfterEvaluateAllEvent event) {
@Override
public void afterConditionalEvaluation(AfterConditionalEvaluationEvent event) {
logEvent(event);
evaluationHitIds.add(event.getExecutedId());
evaluationHitIds.compute(event.getExecutedId(), (k, v) -> v == null ? 1 : v + 1);
}

public List<String> getEvaluationHitIds() {
public Map<String, Integer> getEvaluationHitIds() {
return evaluationHitIds;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,17 +50,17 @@ public class JITDMNResult implements Serializable,

private Map<String, JITDMNDecisionResult> decisionResults = new HashMap<>();

private List<String> evaluationHitIds;
private Map<String, Integer> evaluationHitIds;

public JITDMNResult() {
// Intentionally blank.
}

public JITDMNResult(String namespace, String modelName, org.kie.dmn.api.core.DMNResult dmnResult) {
this(namespace, modelName, dmnResult, Collections.emptyList());
this(namespace, modelName, dmnResult, Collections.emptyMap());
}

public JITDMNResult(String namespace, String modelName, org.kie.dmn.api.core.DMNResult dmnResult, List<String> evaluationHitIds) {
public JITDMNResult(String namespace, String modelName, org.kie.dmn.api.core.DMNResult dmnResult, Map<String, Integer> evaluationHitIds) {
this.namespace = namespace;
this.modelName = modelName;
this.setDmnContext(dmnResult.getContext().getAll());
Expand Down Expand Up @@ -110,11 +110,11 @@ public void setDecisionResults(List<? extends DMNDecisionResult> decisionResults
}
}

public List<String> getEvaluationHitIds() {
public Map<String, Integer> getEvaluationHitIds() {
return evaluationHitIds;
}

public void setEvaluationHitIds(List<String> evaluationHitIds) {
public void setEvaluationHitIds(Map<String, Integer> evaluationHitIds) {
this.evaluationHitIds = evaluationHitIds;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@ void testDecisionTableModelEvaluation() throws IOException {
JITDMNResult dmnResult = jitdmnService.evaluateModel(decisionTableModel, context);

Assertions.assertEquals("LoanEligibility", dmnResult.getModelName());
Assertions.assertEquals("https://github.com/kiegroup/kogito-examples/dmn-quarkus-listener-example", dmnResult.getNamespace());
Assertions.assertEquals("https://github.com/kiegroup/kogito-examples/dmn-quarkus-listener-example",
dmnResult.getNamespace());
Assertions.assertTrue(dmnResult.getMessages().isEmpty());
Assertions.assertEquals("Yes", dmnResult.getDecisionResultByName("Eligibility").getResult());
}
Expand All @@ -101,12 +102,12 @@ void testEvaluationHitIds() throws IOException {
Assertions.assertEquals("DMN_A77074C1-21FE-4F7E-9753-F84661569AFC", dmnResult.getModelName());
Assertions.assertTrue(dmnResult.getMessages().isEmpty());
Assertions.assertEquals(BigDecimal.valueOf(50), dmnResult.getDecisionResultByName("Risk Score").getResult());
List<String> evaluationHitIds = dmnResult.getEvaluationHitIds();
Map<String, Integer> evaluationHitIds = dmnResult.getEvaluationHitIds();
Assertions.assertNotNull(evaluationHitIds);
Assertions.assertEquals(3, evaluationHitIds.size());
Assertions.assertTrue(evaluationHitIds.contains(elseElementId));
Assertions.assertTrue(evaluationHitIds.contains(ruleId0));
Assertions.assertTrue(evaluationHitIds.contains(ruleId3));
Assertions.assertTrue(evaluationHitIds.containsKey(elseElementId));
Assertions.assertTrue(evaluationHitIds.containsKey(ruleId0));
Assertions.assertTrue(evaluationHitIds.containsKey(ruleId3));

context = new HashMap<>();
context.put("Credit Score", "Excellent");
Expand All @@ -118,9 +119,9 @@ void testEvaluationHitIds() throws IOException {
evaluationHitIds = dmnResult.getEvaluationHitIds();
Assertions.assertNotNull(evaluationHitIds);
Assertions.assertEquals(3, evaluationHitIds.size());
Assertions.assertTrue(evaluationHitIds.contains(thenElementId));
Assertions.assertTrue(evaluationHitIds.contains(ruleId1));
Assertions.assertTrue(evaluationHitIds.contains(ruleId4));
Assertions.assertTrue(evaluationHitIds.containsKey(thenElementId));
Assertions.assertTrue(evaluationHitIds.containsKey(ruleId1));
Assertions.assertTrue(evaluationHitIds.containsKey(ruleId4));
}

@Test
Expand All @@ -141,12 +142,12 @@ void testConditionalWithNestedDecisionTableFromRiskScoreEvaluation() throws IOEx

Assertions.assertTrue(dmnResult.getMessages().isEmpty());
Assertions.assertEquals(BigDecimal.valueOf(50), dmnResult.getDecisionResultByName("Risk Score").getResult());
List<String> evaluationHitIds = dmnResult.getEvaluationHitIds();
Map<String, Integer> evaluationHitIds = dmnResult.getEvaluationHitIds();
Assertions.assertNotNull(evaluationHitIds);
Assertions.assertEquals(3, evaluationHitIds.size());
Assertions.assertTrue(evaluationHitIds.contains(thenElementId));
Assertions.assertTrue(evaluationHitIds.contains(thenRuleId0));
Assertions.assertTrue(evaluationHitIds.contains(thenRuleId4));
Assertions.assertTrue(evaluationHitIds.containsKey(thenElementId));
Assertions.assertTrue(evaluationHitIds.containsKey(thenRuleId0));
Assertions.assertTrue(evaluationHitIds.containsKey(thenRuleId4));

context = new HashMap<>();
context.put("Credit Score", "Excellent");
Expand All @@ -159,9 +160,9 @@ void testConditionalWithNestedDecisionTableFromRiskScoreEvaluation() throws IOEx
evaluationHitIds = dmnResult.getEvaluationHitIds();
Assertions.assertNotNull(evaluationHitIds);
Assertions.assertEquals(3, evaluationHitIds.size());
Assertions.assertTrue(evaluationHitIds.contains(elseElementId));
Assertions.assertTrue(evaluationHitIds.contains(elseRuleId2));
Assertions.assertTrue(evaluationHitIds.contains(elseRuleId5));
Assertions.assertTrue(evaluationHitIds.containsKey(elseElementId));
Assertions.assertTrue(evaluationHitIds.containsKey(elseRuleId2));
Assertions.assertTrue(evaluationHitIds.containsKey(elseRuleId5));
}

@Test
Expand All @@ -185,12 +186,12 @@ void testMultipleHitRulesEvaluation() throws IOException {
expectedStatistcs.add(BigDecimal.valueOf(1));
Assertions.assertTrue(dmnResult.getMessages().isEmpty());
Assertions.assertEquals(expectedStatistcs, dmnResult.getDecisionResultByName("Statistics").getResult());
final List<String> evaluationHitIds = dmnResult.getEvaluationHitIds();
final Map<String, Integer> evaluationHitIds = dmnResult.getEvaluationHitIds();
Assertions.assertNotNull(evaluationHitIds);
Assertions.assertEquals(6, evaluationHitIds.size());
Assertions.assertEquals(3, evaluationHitIds.stream().filter(rule0::equals).count());
Assertions.assertEquals(2, evaluationHitIds.stream().filter(rule1::equals).count());
Assertions.assertEquals(1, evaluationHitIds.stream().filter(rule2::equals).count());
Assertions.assertEquals(3, evaluationHitIds.size());
Assertions.assertEquals(3, evaluationHitIds.get(rule0));
Assertions.assertEquals(2, evaluationHitIds.get(rule1));
Assertions.assertEquals(1, evaluationHitIds.get(rule2));
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ArrayNode;
import com.fasterxml.jackson.databind.node.ObjectNode;

import io.quarkus.test.junit.QuarkusTest;
import io.restassured.http.ContentType;
Expand All @@ -43,7 +43,7 @@
@QuarkusTest
public class JITDMNResourceTest {

private static String model;
private static String invalidModel;
private static String modelWithExtensionElements;
private static String modelWithEvaluationHitIds;

Expand All @@ -57,14 +57,14 @@ public class JITDMNResourceTest {

@BeforeAll
public static void setup() throws IOException {
model = getModelFromIoUtils("invalid_models/DMNv1_x/test.dmn");
invalidModel = getModelFromIoUtils("invalid_models/DMNv1_x/test.dmn");
modelWithExtensionElements = getModelFromIoUtils("valid_models/DMNv1_x/testWithExtensionElements.dmn");
modelWithEvaluationHitIds = getModelFromIoUtils("valid_models/DMNv1_5/RiskScore_Simple.dmn");
}

@Test
void testjitEndpoint() {
JITDMNPayload jitdmnpayload = new JITDMNPayload(model, buildContext());
JITDMNPayload jitdmnpayload = new JITDMNPayload(invalidModel, buildContext());
given()
.contentType(ContentType.JSON)
.body(jitdmnpayload)
Expand All @@ -76,14 +76,15 @@ void testjitEndpoint() {

@Test
void testjitdmnResultEndpoint() {
JITDMNPayload jitdmnpayload = new JITDMNPayload(model, buildContext());
JITDMNPayload jitdmnpayload = new JITDMNPayload(modelWithEvaluationHitIds, buildContext());
given()
.contentType(ContentType.JSON)
.body(jitdmnpayload)
.when().post("/jitdmn/dmnresult")
.then()
.statusCode(200)
.body(containsString("Loan Approval"), containsString("Approved"), containsString("xls2dmn"));
.body(containsString("Risk Score"),
containsString("Loan Pre-Qualification"));
}

@Test
Expand All @@ -106,16 +107,18 @@ void testjitdmnResultEndpointWithEvaluationHitIds() throws JsonProcessingExcepti
.extract()
.asString();
JsonNode retrieved = MAPPER.readTree(response);
ArrayNode evaluationHitIdsNode = (ArrayNode) retrieved.get(EVALUATION_HIT_IDS_FIELD_NAME);
Assertions.assertThat(evaluationHitIdsNode).hasSize(3)
.anyMatch(node -> node.asText().equals(elseElementId))
.anyMatch(node -> node.asText().equals(ruleId0))
.anyMatch(node -> node.asText().equals(ruleId3));
ObjectNode evaluationHitIdsNode = (ObjectNode) retrieved.get(EVALUATION_HIT_IDS_FIELD_NAME);
Assertions.assertThat(evaluationHitIdsNode).hasSize(3);
final Map<String, Integer> expectedEvaluationHitIds = Map.of(elseElementId, 1, ruleId0, 1, ruleId3, 1);
evaluationHitIdsNode.fields().forEachRemaining(entry -> {
Assertions.assertThat(expectedEvaluationHitIds).containsKey(entry.getKey());
Assertions.assertThat(expectedEvaluationHitIds.get(entry.getKey())).isEqualTo(entry.getValue().asInt());
});
}

@Test
void testjitExplainabilityEndpoint() {
JITDMNPayload jitdmnpayload = new JITDMNPayload(model, buildContext());
JITDMNPayload jitdmnpayload = new JITDMNPayload(invalidModel, buildContext());
given()
.contentType(ContentType.JSON)
.body(jitdmnpayload)
Expand Down

0 comments on commit 4a6928e

Please sign in to comment.