Skip to content

Commit

Permalink
Allow max samples to be higher
Browse files Browse the repository at this point in the history
Signed-off-by: Adam Li <adam2392@gmail.com>
  • Loading branch information
adam2392 committed Feb 22, 2024
1 parent 3ec238b commit d48716a
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 6 deletions.
23 changes: 18 additions & 5 deletions sklearn/ensemble/_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,14 +104,18 @@ def _get_n_samples_bootstrap(n_samples, max_samples):
"""
Get the number of samples in a bootstrap sample.
The expected total number of unique samples in a bootstrap sample is
required to be at most ``n_samples - 1``.
This is equivalent to the expected number of out-of-bag samples being at
least 1.
Parameters
----------
n_samples : int
Number of samples in the dataset.
max_samples : int or float
The maximum number of samples to draw from the total available:
- if float, this indicates a fraction of the total and should be
the interval `(0.0, 1.0]`;
- if float, this indicates a fraction of the total;
- if int, this indicates the exact number of samples;
- if None, this indicates the total number of samples.
Expand All @@ -124,12 +128,21 @@ def _get_n_samples_bootstrap(n_samples, max_samples):
return n_samples

if isinstance(max_samples, Integral):
if max_samples > n_samples:
msg = "`max_samples` must be <= n_samples={} but got value {}"
raise ValueError(msg.format(n_samples, max_samples))
expected_oob_samples = (1 - np.exp(-max_samples / n_samples)) * n_samples
if expected_oob_samples >= n_samples - 1:
raise ValueError(
"The expected number of unique samples in the bootstrap sample"
f" must be at most {n_samples - 1}. It is: {expected_oob_samples}"
)
return max_samples

if isinstance(max_samples, Real):
expected_oob_samples = (1 - np.exp(-max_samples)) * n_samples
if expected_oob_samples >= n_samples - 1:
raise ValueError(
"The expected number of unique samples in the bootstrap sample"
f" must be at most {n_samples - 1}. It is: {expected_oob_samples}"
)
return max(round(n_samples * max_samples), 1)


Expand Down
5 changes: 4 additions & 1 deletion sklearn/ensemble/tests/test_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -1660,7 +1660,10 @@ def test_max_samples_bootstrap(name):
def test_large_max_samples_exception(name):
# Check invalid `max_samples`
est = FOREST_CLASSIFIERS_REGRESSORS[name](bootstrap=True, max_samples=int(1e9))
match = "`max_samples` must be <= n_samples=6 but got value 1000000000"
# TODO: remove the following line when the issue is fixed
# https://github.com/scikit-learn/scikit-learn/issues/28507
# match = "`max_samples` must be <= n_samples=6 but got value 1000000000"
match = "The expected number of unique samples"
with pytest.raises(ValueError, match=match):
est.fit(X, y)

Expand Down

0 comments on commit d48716a

Please sign in to comment.