diff --git a/sklearn/ensemble/_forest.py b/sklearn/ensemble/_forest.py index 50e9bef4f55f1..3827359b9162e 100644 --- a/sklearn/ensemble/_forest.py +++ b/sklearn/ensemble/_forest.py @@ -104,14 +104,18 @@ def _get_n_samples_bootstrap(n_samples, max_samples): """ Get the number of samples in a bootstrap sample. + The expected total number of unique samples in a bootstrap sample is + required to be at most ``n_samples - 1``. + This is equivalent to the expected number of out-of-bag samples being at + least 1. + Parameters ---------- n_samples : int Number of samples in the dataset. max_samples : int or float The maximum number of samples to draw from the total available: - - if float, this indicates a fraction of the total and should be - the interval `(0.0, 1.0]`; + - if float, this indicates a fraction of the total; - if int, this indicates the exact number of samples; - if None, this indicates the total number of samples. @@ -124,12 +128,21 @@ def _get_n_samples_bootstrap(n_samples, max_samples): return n_samples if isinstance(max_samples, Integral): - if max_samples > n_samples: - msg = "`max_samples` must be <= n_samples={} but got value {}" - raise ValueError(msg.format(n_samples, max_samples)) + expected_oob_samples = (1 - np.exp(-max_samples / n_samples)) * n_samples + if expected_oob_samples >= n_samples - 1: + raise ValueError( + "The expected number of unique samples in the bootstrap sample" + f" must be at most {n_samples - 1}. It is: {expected_oob_samples}" + ) return max_samples if isinstance(max_samples, Real): + expected_oob_samples = (1 - np.exp(-max_samples)) * n_samples + if expected_oob_samples >= n_samples - 1: + raise ValueError( + "The expected number of unique samples in the bootstrap sample" + f" must be at most {n_samples - 1}. It is: {expected_oob_samples}" + ) return max(round(n_samples * max_samples), 1) diff --git a/sklearn/ensemble/tests/test_forest.py b/sklearn/ensemble/tests/test_forest.py index a51d240c87d4e..7914823d48ccf 100644 --- a/sklearn/ensemble/tests/test_forest.py +++ b/sklearn/ensemble/tests/test_forest.py @@ -1660,7 +1660,10 @@ def test_max_samples_bootstrap(name): def test_large_max_samples_exception(name): # Check invalid `max_samples` est = FOREST_CLASSIFIERS_REGRESSORS[name](bootstrap=True, max_samples=int(1e9)) - match = "`max_samples` must be <= n_samples=6 but got value 1000000000" + # TODO: remove the following line when the issue is fixed + # https://github.com/scikit-learn/scikit-learn/issues/28507 + # match = "`max_samples` must be <= n_samples=6 but got value 1000000000" + match = "The expected number of unique samples" with pytest.raises(ValueError, match=match): est.fit(X, y)