Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow MapieRegressor to use any split strategy #386

Merged
merged 15 commits into from
Dec 20, 2023
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions HISTORY.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ History

##### (##########)
------------------
* Allow to use more split methods for MapieRegressor (ShuffleSplit, PredefinedSplit).
* Integrate ConformityScore into MapieTimeSeriesRegressor.
* Add new checks for metrics calculations.
* Fix reference for residual normalised score in documentation.
Expand Down
17 changes: 10 additions & 7 deletions mapie/estimator/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@
import numpy as np
from joblib import Parallel, delayed
from sklearn.base import RegressorMixin, clone
from sklearn.model_selection import BaseCrossValidator, ShuffleSplit
from sklearn.model_selection import BaseCrossValidator
from sklearn.utils import _safe_indexing
from sklearn.utils.validation import (_num_samples, check_is_fitted)

from mapie._typing import ArrayLike, NDArray
from mapie.aggregation_functions import aggregate_all, phi2D
from mapie.utils import (check_nan_in_aposteriori_prediction,
from mapie.utils import (check_nan_in_aposteriori_prediction, check_no_agg_cv,
fit_estimator)
from mapie.estimator.interface import EnsembleEstimator

Expand Down Expand Up @@ -152,6 +152,7 @@ class EnsembleRegressor(EnsembleEstimator):
"single_estimator_",
"estimators_",
"k_",
"use_split_method",
thibaultcordier marked this conversation as resolved.
Show resolved Hide resolved
]

def __init__(
Expand Down Expand Up @@ -278,10 +279,10 @@ def _aggregate_with_mask(
ArrayLike of shape (n_samples_test,)
Array of aggregated predictions for each testing sample.
"""
if self.method in self.no_agg_methods_ or self.cv in self.no_agg_cv_:
if self.method in self.no_agg_methods_ or self.use_split_method:
raise ValueError(
"There should not be aggregation of predictions "
f"if cv is in '{self.no_agg_cv_}' "
f"if cv is in '{self.no_agg_cv_}', if cv >=2 "
thibaultcordier marked this conversation as resolved.
Show resolved Hide resolved
f"or if method is in '{self.no_agg_methods_}'."
)
elif self.agg_function == "median":
Expand Down Expand Up @@ -406,6 +407,7 @@ def fit(
estimators_: List[RegressorMixin] = []
full_indexes = np.arange(_num_samples(X))
cv = self.cv
self.use_split_method = check_no_agg_cv(X, self.cv, self.no_agg_cv_)
estimator = self.estimator
n_samples = _num_samples(y)

Expand Down Expand Up @@ -434,8 +436,9 @@ def fit(
)
for train_index, _ in cv.split(X)
)
if isinstance(cv, ShuffleSplit):
single_estimator_ = estimators_[0]
# In split-CP, we keep only the model fitted on train dataset
if self.use_split_method:
single_estimator_ = estimators_[0]

self.single_estimator_ = single_estimator_
self.estimators_ = estimators_
Expand Down Expand Up @@ -487,7 +490,7 @@ def predict(
if not return_multi_pred and not ensemble:
return y_pred

if self.method in self.no_agg_methods_ or self.cv in self.no_agg_cv_:
if self.method in self.no_agg_methods_ or self.use_split_method:
y_pred_multi_low = y_pred[:, np.newaxis]
y_pred_multi_up = y_pred[:, np.newaxis]
else:
Expand Down
4 changes: 3 additions & 1 deletion mapie/regression/regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ def _check_agg_function(
"You need to specify an aggregation function when "
f"cv's type is in {self.cv_need_agg_function_}."
)
elif (agg_function is not None) or (self.cv in self.no_agg_cv_):
elif agg_function is not None:
thibaultcordier marked this conversation as resolved.
Show resolved Hide resolved
return agg_function
else:
return "mean"
Expand Down Expand Up @@ -507,6 +507,8 @@ def fit(
)
# Fit the prediction function
self.estimator_ = self.estimator_.fit(X, y, sample_weight)

# Predict on calibration data
y_pred = self.estimator_.predict_calib(X)

# Compute the conformity scores (manage jk-ab case)
Expand Down
5 changes: 3 additions & 2 deletions mapie/regression/time_series_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,8 +367,9 @@ def predict(
self.lower_quantiles_ = lower_quantiles
self.higher_quantiles_ = higher_quantiles

if self.method in self.no_agg_methods_ \
or self.cv in self.no_agg_cv_:
if self.method in self.no_agg_methods_ or (
thibaultcordier marked this conversation as resolved.
Show resolved Hide resolved
self.estimator_.use_split_method
):
y_pred_low = y_pred[:, np.newaxis] + lower_quantiles
y_pred_up = y_pred[:, np.newaxis] + higher_quantiles
else:
Expand Down
9 changes: 6 additions & 3 deletions mapie/tests/test_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
from sklearn.dummy import DummyRegressor
from sklearn.impute import SimpleImputer
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import (KFold, LeaveOneOut, ShuffleSplit,
train_test_split)
from sklearn.model_selection import (KFold, LeaveOneOut, PredefinedSplit,
ShuffleSplit, train_test_split)
from sklearn.pipeline import Pipeline, make_pipeline
from sklearn.preprocessing import OneHotEncoder
from sklearn.utils.validation import check_is_fitted
Expand Down Expand Up @@ -211,7 +211,9 @@ def test_valid_agg_function(agg_function: str) -> None:

@pytest.mark.parametrize(
"cv", [None, -1, 2, KFold(), LeaveOneOut(),
ShuffleSplit(n_splits=1), "prefit", "split"]
ShuffleSplit(n_splits=1),
PredefinedSplit(test_fold=[-1]*3+[0]*3),
"prefit", "split"]
)
def test_valid_cv(cv: Any) -> None:
"""Test that valid cv raise no errors."""
Expand Down Expand Up @@ -526,6 +528,7 @@ def test_aggregate_with_mask_with_invalid_agg_function() -> None:
0.20,
False
)
ens_reg.use_split_method = False
with pytest.raises(
ValueError,
match=r".*The value of self.agg_function is not correct*",
Expand Down
25 changes: 20 additions & 5 deletions mapie/tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from __future__ import annotations

from typing import Any, Optional
from typing import Any, Optional, Tuple

import numpy as np
import pytest
from numpy.random import RandomState
from sklearn.datasets import make_regression
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import BaseCrossValidator
from sklearn.model_selection import BaseCrossValidator, KFold, ShuffleSplit
from sklearn.utils.validation import check_is_fitted

from mapie._typing import ArrayLike, NDArray
Expand All @@ -16,9 +16,10 @@
check_array_nan, check_array_inf, check_arrays_length,
check_binary_zero_one, check_cv,
check_lower_upper_bounds, check_n_features_in,
check_n_jobs, check_null_weight, check_number_bins,
check_split_strategy, check_verbose,
compute_quantiles, fit_estimator, get_binning_groups)
check_n_jobs, check_no_agg_cv, check_null_weight,
check_number_bins, check_split_strategy,
check_verbose, compute_quantiles, fit_estimator,
get_binning_groups)

X_toy = np.array([0, 1, 2, 3, 4, 5]).reshape(-1, 1)
y_toy = np.array([5, 7, 9, 11, 13, 15])
Expand Down Expand Up @@ -474,3 +475,17 @@ def test_check_cv_same_split_no_random_state(cv: BaseCrossValidator) -> None:

for i in range(cv.get_n_splits()):
np.testing.assert_allclose(train_indices_1[i], train_indices_2[i])


@pytest.mark.parametrize(
"cv_result", [
(1, True), (2, False),
("split", True), (KFold(5), False),
(ShuffleSplit(1), True),
(ShuffleSplit(2), False)
]
)
def test_check_no_agg_cv(cv_result: Tuple) -> None:
array = ["prefit", "split"]
cv, result = cv_result
np.testing.assert_almost_equal(check_no_agg_cv(X_toy, cv, array), result)
48 changes: 42 additions & 6 deletions mapie/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,21 +129,22 @@ def fit_estimator(


def check_cv(
cv: Optional[Union[int, str, BaseCrossValidator]] = None,
cv: Optional[Union[int, str, BaseCrossValidator, BaseShuffleSplit]] = None,
test_size: Optional[Union[int, float]] = None,
random_state: Optional[Union[int, np.random.RandomState]] = None,
) -> Union[str, BaseCrossValidator]:
) -> Union[str, BaseCrossValidator, BaseShuffleSplit]:
"""
Check if cross-validator is
``None``, ``int``, ``"prefit"``, ``"split"``or ``BaseCrossValidator``.
``None``, ``int``, ``"prefit"``, ``"split"``, ``BaseCrossValidator`` or
``BaseShuffleSplit``.
Return a ``LeaveOneOut`` instance if integer equal to -1.
Return a ``KFold`` instance if integer superior or equal to 2.
Return a ``KFold`` instance if ``None``.
Else raise error.

Parameters
----------
cv: Optional[Union[int, str, BaseCrossValidator]], optional
cv: Optional[Union[int, str, BaseCrossValidator, BaseShuffleSplit]]
Cross-validator to check, by default ``None``.

test_size: Optional[Union[int, float]]
Expand All @@ -163,8 +164,8 @@ def check_cv(

Returns
-------
Optional[Union[float, str]]
'prefit' or None.
Union[str, BaseCrossValidator, BaseShuffleSplit]
The cast `cv` parameter.

Raises
------
Expand Down Expand Up @@ -208,6 +209,41 @@ def check_cv(
)


def check_no_agg_cv(
X: ArrayLike,
cv: Union[int, str, BaseCrossValidator, BaseShuffleSplit],
no_agg_cv_array: list,
) -> bool:
"""
Check if cross-validator is ``"prefit"``, ``"split"`` or any split
equivalent `BaseCrossValidator` or `BaseShuffleSplit`.

Parameters
----------
X: ArrayLike of shape (n_samples, n_features)
Training data.

cv: Union[int, str, BaseCrossValidator, BaseShuffleSplit]
Cross-validator to check.

no_agg_cv_array: list
List of all non-aggregated cv methods.

Returns
-------
bool
True if `cv` is a split equivalent / non-aggregated cv method.
"""
if isinstance(cv, str):
return cv in no_agg_cv_array
elif isinstance(cv, int):
return cv == 1
if hasattr(cv, "get_n_splits"):
vincentblot28 marked this conversation as resolved.
Show resolved Hide resolved
return cv.get_n_splits(X) == 1
else:
vincentblot28 marked this conversation as resolved.
Show resolved Hide resolved
return False


def check_alpha(
alpha: Optional[Union[float, Iterable[float]]] = None
) -> Optional[ArrayLike]:
Expand Down
Loading