Skip to content

Commit

Permalink
Calibrator factory decorator (#341)
Browse files Browse the repository at this point in the history
* Refactor CalibratorFactory to use decorator

* formatting, linting and whatnot

* Fix broken tests
  • Loading branch information
nnansters authored Nov 20, 2023
1 parent 0478b24 commit ed8cb08
Show file tree
Hide file tree
Showing 35 changed files with 10,986 additions and 10,901 deletions.
4,176 changes: 2,088 additions & 2,088 deletions docs/_static/butterfly-scatterplot.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
1,172 changes: 586 additions & 586 deletions docs/_static/example_california_latitude_longitude_scatter.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
1,378 changes: 689 additions & 689 deletions docs/_static/example_california_performance_estimation_tmp.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
1,388 changes: 694 additions & 694 deletions docs/_static/example_green_taxi_feature_importance.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
576 changes: 288 additions & 288 deletions docs/_static/example_green_taxi_tip_amount_boxplot.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
754 changes: 377 additions & 377 deletions docs/_static/example_green_taxi_tip_amount_distribution.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
942 changes: 471 additions & 471 deletions docs/_static/how-it-works-dle-data.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
1,438 changes: 719 additions & 719 deletions docs/_static/how-it-works-dle-regression-PI.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
1,646 changes: 823 additions & 823 deletions docs/_static/how-it-works-dle-regression-abs-errors-hist.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
1,474 changes: 737 additions & 737 deletions docs/_static/how-it-works-dle-regression-errors-hist.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
1,184 changes: 592 additions & 592 deletions docs/_static/how-it-works-dle-regression.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
1,330 changes: 665 additions & 665 deletions docs/_static/how-it-works/chunks_stability_of_accuracy.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
1,508 changes: 754 additions & 754 deletions docs/_static/how-it-works/ranking-abs-perf-features-compare.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
1,366 changes: 683 additions & 683 deletions docs/_static/how-it-works/ranking-abs-perf.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion docs/examples.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@ Examples
:maxdepth: 2

examples/california_housing
examples/green_taxi
examples/green_taxi
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ For more information about estimating these metrics, refer to the :ref:`multicla

We also support the following *complex* metric for multiclass classification performance calculation:

* **confusion_matrix**
* **confusion_matrix**

For more information about estimating this metrics, refer to the :ref:`multiclass-confusion-matrix-estimation` section.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ The results can be plotted for visual inspection. Our plot contains several key
* *The purple step plot* shows the performance in each chunk of the analysis period. Thick squared point
markers indicate the middle of these chunks.

* *The blue step plot* shows the performance in each chunk of the reference period. Thick squared point markers indicate
* *The blue step plot* shows the performance in each chunk of the reference period. Thick squared point markers indicate
the middle of these chunks.

* *The gray vertical line* splits the reference and analysis periods.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,4 +133,4 @@ what feature changes may be contributing to any performance changes. We can also
and :ref:`compare it with the estimated results<compare_estimated_and_realized_performance>`.

It is also wise to check whether the model's performance is satisfactory
according to business requirements. This is an ad-hoc investigation that is not covered by NannyML.
according to business requirements. This is an ad-hoc investigation that is not covered by NannyML.
1 change: 1 addition & 0 deletions nannyml/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,6 +537,7 @@ def _column_is_categorical(column: pd.Series) -> bool:
def _remove_nans(data: pd.Series) -> pd.Series:
...


@overload
def _remove_nans(data: pd.DataFrame, columns: Optional[Iterable[Union[str, Iterable[str]]]]) -> pd.DataFrame:
...
Expand Down
42 changes: 27 additions & 15 deletions nannyml/calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@

"""Calibrating model scores into probabilities."""
import abc
from typing import Any, Callable, List, Optional, Tuple
import warnings
from typing import Any, Callable, Dict, List, Tuple, Type

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -45,10 +46,10 @@ def calibrate(self, y_pred_proba: np.ndarray):
class CalibratorFactory:
"""Factory class to aid in construction of Calibrators."""

_calibrators = {'isotonic': lambda args: IsotonicCalibrator()}
_registry: Dict[str, Type[Calibrator]] = {}

@classmethod
def register_calibrator(cls, key: str, create_calibrator: Callable):
def register_calibrator(cls, key: str, calibrator: Type[Calibrator]):
"""Registers a new calibrator to the index.
This index associates a certain key with a function that can be used to construct a new Calibrator instance.
Expand All @@ -58,17 +59,28 @@ def register_calibrator(cls, key: str, create_calibrator: Callable):
key: str
The key used to retrieve a Calibrator. When providing a key that is already in the index, the value
will be overwritten.
create_calibrator: Callable
calibrator: Type[Calibrator]
A function that - given a ``**kwargs`` argument - create a new instance of a Calibrator subclass.
Examples
--------
>>> CalibratorFactory.register_calibrator('isotonic', lambda kwargs: IsotonicCalibrator())
>>> CalibratorFactory.register_calibrator('isotonic', IsotonicCalibrator)
"""
cls._calibrators[key] = create_calibrator
cls._registry[key] = calibrator

@classmethod
def create(cls, key: Optional[str], **kwargs):
def register(cls, key: str) -> Callable:
def inner_wrapper(wrapped_class: Type[Calibrator]) -> Type[Calibrator]:
if key in cls._registry:
warnings.warn(f"re-registering calibrator with key '{key}'")

cls._registry[key] = wrapped_class
return wrapped_class

return inner_wrapper

@classmethod
def create(cls, key: str = 'isotonic', **kwargs):
"""Creates a new Calibrator given a key value and optional keyword args.
If the provided key equals ``None``, then a new instance of the default Calibrator (IsotonicCalibrator)
Expand All @@ -78,7 +90,7 @@ def create(cls, key: Optional[str], **kwargs):
Parameters
----------
key : str
key : str, default='isotonic'
The key used to retrieve a Calibrator. When providing a key that is already in the index, the value
will be overwritten.
kwargs : dict
Expand All @@ -94,18 +106,18 @@ def create(cls, key: Optional[str], **kwargs):
--------
>>> calibrator = CalibratorFactory.create('isotonic', kwargs={'foo': 'bar'})
"""
default = IsotonicCalibrator()
if key is None:
return default

if key not in cls._calibrators:
if key not in cls._registry:
raise InvalidArgumentsException(
f"calibrator {key} unknown. " f"Please provide one of the following: {cls._calibrators.keys()}"
f"calibrator '{key}' unknown. " f"Please provide one of the following: {cls._registry.keys()}"
)

return cls._calibrators.get(key, default)
calibrator_class = cls._registry.get(key)
assert calibrator_class

return calibrator_class(**kwargs)


@CalibratorFactory.register('isotonic')
class IsotonicCalibrator(Calibrator):
"""Calibrates using IsotonicRegression model."""

Expand Down
1 change: 0 additions & 1 deletion nannyml/drift/ranker.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,6 @@ def _validate_drift_result(rankable_result: RankableResult):
raise InvalidArgumentsException('rankable_result contains no data to use for ranking')

if isinstance(rankable_result, UnivariateResults):

if len(rankable_result.categorical_method_names) > 1:
raise InvalidArgumentsException(
f"Only one categorical drift method should be present in the univariate results."
Expand Down
2 changes: 1 addition & 1 deletion nannyml/drift/univariate/methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from scipy.stats import chi2_contingency, ks_2samp, wasserstein_distance

from nannyml._typing import Self
from nannyml.base import _remove_nans, _column_is_categorical
from nannyml.base import _column_is_categorical, _remove_nans
from nannyml.chunk import Chunker
from nannyml.exceptions import InvalidArgumentsException, NotFittedException
from nannyml.thresholds import Threshold, calculate_threshold_values
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from sklearn.metrics import confusion_matrix, f1_score, precision_score, recall_score, roc_auc_score

from nannyml._typing import ProblemType
from nannyml.base import _remove_nans, _list_missing
from nannyml.base import _list_missing, _remove_nans
from nannyml.chunk import Chunk, Chunker
from nannyml.exceptions import InvalidArgumentsException
from nannyml.performance_calculation.metrics.base import Metric, MetricFactory
Expand Down Expand Up @@ -544,17 +544,15 @@ def _calculate(self, data: pd.DataFrame):
tn_value = self.business_value_matrix[0, 0]
fp_value = self.business_value_matrix[0, 1]
fn_value = self.business_value_matrix[1, 0]
bv_array = np.array(
[[tn_value,fp_value], [fn_value,tp_value]]
)
bv_array = np.array([[tn_value, fp_value], [fn_value, tp_value]])

cm = confusion_matrix(y_true, y_pred)
if self.normalize_business_value == 'per_prediction':
with np.errstate(all="ignore"):
cm = cm / cm.sum(axis=0, keepdims=True)
cm = np.nan_to_num(cm)

return (bv_array*cm).sum()
return (bv_array * cm).sum()

def _sampling_error(self, data: pd.DataFrame) -> float:
return business_value_sampling_error(self._sampling_error_components, data)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from sklearn.preprocessing import LabelBinarizer, label_binarize

from nannyml._typing import ProblemType, class_labels, model_output_column_names
from nannyml.base import _remove_nans, _list_missing
from nannyml.base import _list_missing, _remove_nans
from nannyml.chunk import Chunker
from nannyml.exceptions import InvalidArgumentsException
from nannyml.performance_calculation.metrics.base import Metric, MetricFactory
Expand All @@ -35,14 +35,14 @@
auroc_sampling_error_components,
f1_sampling_error,
f1_sampling_error_components,
multiclass_confusion_matrix_sampling_error,
multiclass_confusion_matrix_sampling_error_components,
precision_sampling_error,
precision_sampling_error_components,
recall_sampling_error,
recall_sampling_error_components,
specificity_sampling_error,
specificity_sampling_error_components,
multiclass_confusion_matrix_sampling_error,
multiclass_confusion_matrix_sampling_error_components,
)
from nannyml.thresholds import Threshold, calculate_threshold_values

Expand Down Expand Up @@ -588,7 +588,6 @@ def __init__(
normalize_confusion_matrix: Optional[str] = None,
**kwargs,
):

"""Creates a new confusion matrix instance."""
super().__init__(
name='confusion_matrix',
Expand All @@ -607,7 +606,6 @@ def __str__(self):
return "confusion_matrix"

def fit(self, reference_data: pd.DataFrame, chunker: Chunker):

# _fit
# realized perf on chunks
# set thresholds
Expand Down Expand Up @@ -700,7 +698,6 @@ def _calculate(self, data: pd.DataFrame) -> Union[np.ndarray, float]:
return cm

def get_chunk_record(self, chunk_data: pd.DataFrame) -> Dict[str, Union[float, bool]]:

if self.classes is None:
raise ValueError("classes must be set before calling this method")

Expand All @@ -714,7 +711,6 @@ def get_chunk_record(self, chunk_data: pd.DataFrame) -> Dict[str, Union[float, b

for true_class in self.classes:
for pred_class in self.classes:

column_name = f'true_{true_class}_pred_{pred_class}'

chunk_record[f"{column_name}_sampling_error"] = sampling_errors[
Expand Down
2 changes: 1 addition & 1 deletion nannyml/performance_calculation/metrics/regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
)

from nannyml._typing import ProblemType
from nannyml.base import _remove_nans, _list_missing, _raise_exception_for_negative_values
from nannyml.base import _list_missing, _raise_exception_for_negative_values, _remove_nans
from nannyml.performance_calculation.metrics.base import Metric, MetricFactory
from nannyml.sampling_error.regression import (
mae_sampling_error,
Expand Down
2 changes: 1 addition & 1 deletion nannyml/performance_estimation/confidence_based/cbpe.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def __init__(
chunk_number: Optional[int] = None,
chunk_period: Optional[str] = None,
chunker: Optional[Chunker] = None,
calibration: Optional[str] = None,
calibration: str = 'isotonic',
calibrator: Optional[Calibrator] = None,
thresholds: Optional[Dict[str, Threshold]] = None,
normalize_confusion_matrix: Optional[str] = None,
Expand Down
24 changes: 5 additions & 19 deletions nannyml/performance_estimation/confidence_based/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -1573,16 +1573,14 @@ def _realized_performance(self, data: pd.DataFrame) -> float:
tn_value = self.business_value_matrix[0, 0]
fp_value = self.business_value_matrix[0, 1]
fn_value = self.business_value_matrix[1, 0]
bv_array = np.array(
[[tn_value,fp_value], [fn_value,tp_value]]
)
bv_array = np.array([[tn_value, fp_value], [fn_value, tp_value]])

cm = confusion_matrix(y_true, y_pred)
if self.normalize_business_value == 'per_prediction':
with np.errstate(all="ignore"):
cm = cm / cm.sum(axis=0, keepdims=True)
cm = np.nan_to_num(cm)
return (bv_array*cm).sum()
return (bv_array * cm).sum()

def _estimate(self, chunk_data: pd.DataFrame) -> float:
y_pred_proba = chunk_data[self.y_pred_proba]
Expand Down Expand Up @@ -1630,9 +1628,7 @@ def estimate_business_value(
est_tp_ratio = np.mean(np.where(y_pred == 1, y_pred_proba, 0))
est_fp_ratio = np.mean(np.where(y_pred == 1, 1 - y_pred_proba, 0))
est_fn_ratio = np.mean(np.where(y_pred == 0, y_pred_proba, 0))
cm = np.array(
[[est_tn_ratio, est_fp_ratio], [est_fn_ratio, est_tp_ratio]]
)*len(y_pred)
cm = np.array([[est_tn_ratio, est_fp_ratio], [est_fn_ratio, est_tp_ratio]]) * len(y_pred)
if normalize_business_value == 'per_prediction':
with np.errstate(all="ignore"):
cm = cm / cm.sum(axis=0, keepdims=True)
Expand All @@ -1642,11 +1638,9 @@ def estimate_business_value(
tn_value = business_value_matrix[0, 0]
fp_value = business_value_matrix[0, 1]
fn_value = business_value_matrix[1, 0]
bv_array = np.array(
[[tn_value,fp_value], [fn_value,tp_value]]
)
bv_array = np.array([[tn_value, fp_value], [fn_value, tp_value]])

return (bv_array*cm).sum()
return (bv_array * cm).sum()


def _get_binarized_multiclass_predictions(data: pd.DataFrame, y_pred: str, y_pred_proba: ModelOutputsType):
Expand Down Expand Up @@ -2108,7 +2102,6 @@ def __init__(
normalize_confusion_matrix: Optional[str] = None,
**kwargs,
):

if isinstance(y_pred_proba, str):
raise ValueError(
"y_pred_proba must be a dictionary with class labels as keys and pred_proba column names as values"
Expand Down Expand Up @@ -2167,7 +2160,6 @@ def fit(self, reference_data: pd.DataFrame): # override the superclass fit meth
return

def _fit(self, reference_data: pd.DataFrame):

self._confusion_matrix_sampling_error_components = mse.multiclass_confusion_matrix_sampling_error_components(
y_true_reference=reference_data[self.y_true],
y_pred_reference=reference_data[self.y_pred],
Expand All @@ -2177,7 +2169,6 @@ def _fit(self, reference_data: pd.DataFrame):
def _multiclass_confusion_matrix_alert_thresholds(
self, reference_chunks: List[Chunk]
) -> Dict[str, Tuple[Optional[float], Optional[float]]]:

realized_chunk_performance = np.asarray(
[self._multi_class_confusion_matrix_realized_performance(chunk.data) for chunk in reference_chunks]
)
Expand Down Expand Up @@ -2224,22 +2215,19 @@ def _multiclass_confusion_matrix_confidence_deviations(
self,
reference_chunks: List[Chunk],
) -> Dict[str, float]:

confidence_deviations = {}

num_classes = len(self.classes)

for i in range(num_classes):
for j in range(num_classes):

confidence_deviations[f'true_{self.classes[i]}_pred_{self.classes[j]}'] = np.std(
[self._get_multiclass_confusion_matrix_estimate(chunk.data)[i, j] for chunk in reference_chunks]
)

return confidence_deviations

def _get_multiclass_confusion_matrix_estimate(self, chunk_data: pd.DataFrame) -> np.ndarray:

if isinstance(self.y_pred_proba, str):
raise ValueError(
"y_pred_proba must be a dictionary with class labels as keys and pred_proba column names as values"
Expand Down Expand Up @@ -2282,7 +2270,6 @@ def _get_multiclass_confusion_matrix_estimate(self, chunk_data: pd.DataFrame) ->
return normalized_est_confusion_matrix

def get_chunk_record(self, chunk_data: pd.DataFrame) -> Dict:

chunk_record = {}

estimated_cm = self._get_multiclass_confusion_matrix_estimate(chunk_data)
Expand All @@ -2295,7 +2282,6 @@ def get_chunk_record(self, chunk_data: pd.DataFrame) -> Dict:

for true_class in self.classes:
for pred_class in self.classes:

chunk_record[f'estimated_true_{true_class}_pred_{pred_class}'] = estimated_cm[
self.classes.index(true_class), self.classes.index(pred_class)
]
Expand Down
3 changes: 0 additions & 3 deletions nannyml/plots/blueprints/comparisons.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,6 @@ def _plot_compare_step_to_step( # noqa: C901
metric_2_color=Colors.BLUE_SKY_CRAYOLA,
**kwargs,
) -> Figure:

_metric_1_kwargs = {k.replace('metric_1_', ''): v for k, v in kwargs.items() if k.startswith('metric_1_')}
_metric_2_kwargs = {k.replace('metric_2_', ''): v for k, v in kwargs.items() if k.startswith('metric_2_')}

Expand Down Expand Up @@ -442,7 +441,6 @@ def _plot_compare_step_to_step( # noqa: C901
# endregion

if has_analysis_results:

# region analysis metric 1

_hover = hover or Hover(
Expand Down Expand Up @@ -693,7 +691,6 @@ def _is_estimated_result(result: Result) -> bool:

class ResultComparison:
def __init__(self, result: Result, other: Result, plot_kwargs: Dict[str, Any], title: Optional[str] = None):

if len(result.keys()) != 1 or len(result.keys()) != 1:
raise InvalidArgumentsException(
f"you're comparing {len(result.keys())} metrics to {len(result.keys())} "
Expand Down
Loading

0 comments on commit ed8cb08

Please sign in to comment.