Skip to content

Commit

Permalink
expand test to all cm metrics
Browse files Browse the repository at this point in the history
  • Loading branch information
nikml committed Jun 20, 2024
1 parent 3503247 commit fe9f5d7
Showing 1 changed file with 17 additions and 3 deletions.
20 changes: 17 additions & 3 deletions tests/performance_estimation/CBPE/test_cbpe_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -3473,6 +3473,12 @@ def test_method_logs_warning_when_lower_threshold_is_overridden_by_metric_limits
pd.DataFrame(
{
'key': ['[0:19999]', '[20000:39999]', '[40000:59999]'],
'realized_roc_auc': [0.909805, 0.840071, np.nan],
'realized_f1': [0.759170, 0.658896, np.nan],
'realized_precision': [0.759265, 0.660188, np.nan],
'realized_recall': [0.759149, 0.658760, np.nan],
'realized_specificity': [0.879632, 0.829581, np.nan],
'realized_accuracy': [0.75925, 0.65950, np.nan],
'realized_true_highstreet_card_pred_highstreet_card': [
4912.0,
4702.0,
Expand Down Expand Up @@ -3523,7 +3529,7 @@ def test_method_logs_warning_when_lower_threshold_is_overridden_by_metric_limits
),
]
)
def test_cbpe_for_multiclass_classification_cm_with_nans(calculator_opts, estimated, realized): # noqa: D103
def test_cbpe_for_multiclass_classification_cm_with_nans(calculator_opts, realized): # noqa: D103
"""Test Nan Handling of CM MC metric."""
reference, analysis, targets = load_synthetic_multiclass_classification_dataset()
analysis = analysis.merge(targets, left_index=True, right_index=True)
Expand All @@ -3537,11 +3543,13 @@ def test_cbpe_for_multiclass_classification_cm_with_nans(calculator_opts, estima
y_pred='y_pred',
y_true='y_true',
problem_type='classification_multiclass',
metrics=['confusion_matrix'],
metrics=['roc_auc', 'f1', 'precision', 'recall', 'specificity', 'accuracy', 'confusion_matrix'],
**calculator_opts,
).fit(reference)
result = cbpe.estimate(analysis)
column_names = [
column_names = [(m.name, 'realized') for m in result.metrics]
column_names = [c for c in column_names if c[0] != 'confusion_matrix']
column_names += [
('true_highstreet_card_pred_highstreet_card', 'realized'),
('true_highstreet_card_pred_prepaid_card', 'realized'),
('true_highstreet_card_pred_upmarket_card', 'realized'),
Expand All @@ -3555,6 +3563,12 @@ def test_cbpe_for_multiclass_classification_cm_with_nans(calculator_opts, estima
sut = result.filter(period='analysis').to_df()[[('chunk', 'key')] + column_names]
sut.columns = [
'key',
'realized_roc_auc',
'realized_f1',
'realized_precision',
'realized_recall',
'realized_specificity',
'realized_accuracy',
'realized_true_highstreet_card_pred_highstreet_card',
'realized_true_highstreet_card_pred_prepaid_card',
'realized_true_highstreet_card_pred_upmarket_card',
Expand Down

0 comments on commit fe9f5d7

Please sign in to comment.