Skip to content

Commit

Permalink
Merge pull request #400 from NannyML/mc_array_fix
Browse files Browse the repository at this point in the history
fix nan handling for MC CM CBPE realized performance
  • Loading branch information
nnansters authored Jun 24, 2024
2 parents 4f73494 + fe9f5d7 commit a4baf5e
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 4 deletions.
11 changes: 7 additions & 4 deletions nannyml/performance_estimation/confidence_based/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -3133,29 +3133,32 @@ def _multiclass_confusion_matrix_alert_thresholds(
return alert_thresholds

def _multi_class_confusion_matrix_realized_performance(self, data: pd.DataFrame) -> Union[np.ndarray, float]:
# Create appropriate nan array to return in case of error
num_classes = len(self.classes)
nan_array = np.full(shape=(num_classes, num_classes), fill_value=np.nan)
try:
_list_missing([self.y_true, self.y_pred], data)
except InvalidArgumentsException as ex:
if "missing required columns" in str(ex):
self._logger.debug(str(ex))
return np.NaN
return nan_array
else:
raise ex

data, empty = common_nan_removal(data, [self.y_true, self.y_pred])
if empty:
warnings.warn(f"Too many missing values, cannot calculate {self.display_name}. " f"Returning NaN.")
return np.NaN
return nan_array

y_true = data[self.y_true]
if y_true.nunique() <= 1:
warnings.warn(f"Too few unique values present in 'y_true', returning NaN as realized {self.display_name}.")
return np.NaN
return nan_array
if data[self.y_pred].nunique() <= 1:
warnings.warn(
f"Too few unique values present in 'y_pred', returning NaN as realized {self.display_name} score."
)
return np.NaN
return nan_array

cm = confusion_matrix(
data[self.y_true], data[self.y_pred], labels=self.classes, normalize=self.normalize_confusion_matrix
Expand Down
118 changes: 118 additions & 0 deletions tests/performance_estimation/CBPE/test_cbpe_metrics.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Tests."""

import pandas as pd
import numpy as np
import pytest

from nannyml.chunk import DefaultChunker, SizeBasedChunker
Expand Down Expand Up @@ -3462,3 +3463,120 @@ def test_method_logs_warning_when_lower_threshold_is_overridden_by_metric_limits
f'{metric.display_name} lower threshold value -1 overridden by '
f'lower threshold value limit {metric.lower_threshold_value_limit}' in caplog.messages
)


@pytest.mark.parametrize(
    'calculator_opts, realized',
    [
        (
            {'chunk_size': 20000},
            pd.DataFrame(
                {
                    'key': ['[0:19999]', '[20000:39999]', '[40000:59999]'],
                    'realized_roc_auc': [0.909805, 0.840071, np.nan],
                    'realized_f1': [0.759170, 0.658896, np.nan],
                    'realized_precision': [0.759265, 0.660188, np.nan],
                    'realized_recall': [0.759149, 0.658760, np.nan],
                    'realized_specificity': [0.879632, 0.829581, np.nan],
                    'realized_accuracy': [0.75925, 0.65950, np.nan],
                    'realized_true_highstreet_card_pred_highstreet_card': [
                        4912.0,
                        4702.0,
                        np.nan,
                    ],
                    'realized_true_highstreet_card_pred_prepaid_card': [
                        870.0,
                        1083.0,
                        np.nan,
                    ],
                    'realized_true_highstreet_card_pred_upmarket_card': [
                        799.0,
                        1009.0,
                        np.nan,
                    ],
                    'realized_true_prepaid_card_pred_highstreet_card': [
                        846.0,
                        1367.0,
                        np.nan,
                    ],
                    'realized_true_prepaid_card_pred_prepaid_card': [
                        5203.0,
                        3974.0,
                        np.nan,
                    ],
                    'realized_true_prepaid_card_pred_upmarket_card': [
                        690.0,
                        1080.0,
                        np.nan,
                    ],
                    'realized_true_upmarket_card_pred_highstreet_card': [
                        837.0,
                        1282.0,
                        np.nan,
                    ],
                    'realized_true_upmarket_card_pred_prepaid_card': [
                        773.0,
                        989.0,
                        np.nan,
                    ],
                    'realized_true_upmarket_card_pred_upmarket_card': [
                        5070.0,
                        4514.0,
                        np.nan,
                    ],
                }
            ),
        ),
    ],
)
def test_cbpe_for_multiclass_classification_cm_with_nans(calculator_opts, realized):
    """Check realized multiclass metrics when the last chunk has no targets.

    The targets of the final 20,000 analysis rows are blanked out, so with a
    chunk size of 20,000 the third chunk contains no usable ``y_true`` values.
    Every realized metric, including each cell of the confusion matrix, is
    expected to be NaN for that chunk while the first two chunks keep their
    regular values.
    """
    reference, analysis, targets = load_synthetic_multiclass_classification_dataset()
    analysis = analysis.merge(targets, left_index=True, right_index=True)

    # Assign through the frame in a single positional indexing step. The
    # chained form ``analysis.y_true[-20_000:] = ...`` writes to an
    # intermediate Series and does not reach ``analysis`` when pandas
    # Copy-on-Write is active, which would silently disable this test.
    analysis.iloc[-20_000:, analysis.columns.get_loc('y_true')] = np.nan

    cbpe = CBPE(
        y_pred_proba={
            'upmarket_card': 'y_pred_proba_upmarket_card',
            'highstreet_card': 'y_pred_proba_highstreet_card',
            'prepaid_card': 'y_pred_proba_prepaid_card',
        },
        y_pred='y_pred',
        y_true='y_true',
        problem_type='classification_multiclass',
        metrics=['roc_auc', 'f1', 'precision', 'recall', 'specificity', 'accuracy', 'confusion_matrix'],
        **calculator_opts,
    ).fit(reference)
    result = cbpe.estimate(analysis)

    # The confusion matrix is reported as one result column per
    # (true class, predicted class) cell instead of a single metric column.
    # Iterating true-major over the class labels reproduces the column order
    # used by the expected frame.
    classes = ['highstreet_card', 'prepaid_card', 'upmarket_card']
    cm_cells = [
        f'true_{true_class}_pred_{pred_class}' for true_class in classes for pred_class in classes
    ]

    column_names = [(m.name, 'realized') for m in result.metrics if m.name != 'confusion_matrix']
    column_names += [(cell, 'realized') for cell in cm_cells]

    sut = result.filter(period='analysis').to_df()[[('chunk', 'key')] + column_names]
    # Flatten the two-level result columns to the flat labels of ``realized``.
    sut.columns = [
        'key',
        'realized_roc_auc',
        'realized_f1',
        'realized_precision',
        'realized_recall',
        'realized_specificity',
        'realized_accuracy',
    ] + [f'realized_{cell}' for cell in cm_cells]

    pd.testing.assert_frame_equal(realized, sut)

0 comments on commit a4baf5e

Please sign in to comment.