ENH Adding estimators_samples_ attribute to forest models (scikit-learn#26736)

Signed-off-by: Adam Li <adam2392@gmail.com>
Co-authored-by: Omar Salman <omar.salman@arbisoft.com>
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com>
4 people authored Nov 3, 2023
1 parent a1e263a commit 3737909
Showing 3 changed files with 132 additions and 4 deletions.
6 changes: 6 additions & 0 deletions doc/whats_new/v1.4.rst
@@ -294,6 +294,12 @@ Changelog
 - |API| In :class:`ensemble.AdaBoostClassifier`, the `algorithm` argument `SAMME.R` was
   deprecated and will be removed in 1.6. :pr:`26830` by :user:`Stefanie Senger
   <StefanieSenger>`.
+- |Enhancement| A fitted attribute, ``estimators_samples_``, was added to all forest
+  models, including :class:`ensemble.RandomForestClassifier`,
+  :class:`ensemble.RandomForestRegressor`, :class:`ensemble.ExtraTreesClassifier`
+  and :class:`ensemble.ExtraTreesRegressor`, which allows retrieving the training
+  sample indices used to fit each tree estimator.
+  :pr:`26736` by :user:`Adam Li <adam2392>`.

 :mod:`sklearn.feature_selection`
 ................................
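For orientation, here is a minimal usage sketch of the attribute this entry describes (illustrative code, not part of the diff; the dataset and parameters are arbitrary):

```python
# Minimal usage sketch (illustrative, not part of the diff): fit a small
# forest and inspect the in-bag sample indices exposed by the new attribute.
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

X, y = make_classification(n_samples=100, random_state=0)
clf = RandomForestClassifier(n_estimators=3, random_state=0).fit(X, y)

for i, indices in enumerate(clf.estimators_samples_):
    # One array of training-sample indices per tree; drawn with replacement
    # because bootstrap=True by default.
    print(f"tree {i}: {len(indices)} draws, {np.unique(indices).size} unique samples")
```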
68 changes: 66 additions & 2 deletions sklearn/ensemble/_forest.py
@@ -130,7 +130,9 @@ def _generate_sample_indices(random_state, n_samples, n_samples_bootstrap):
     Private function used to _parallel_build_trees function."""
 
     random_instance = check_random_state(random_state)
-    sample_indices = random_instance.randint(0, n_samples, n_samples_bootstrap)
+    sample_indices = random_instance.randint(
+        0, n_samples, n_samples_bootstrap, dtype=np.int32
+    )
 
     return sample_indices
 
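The helper stays deterministic for a fixed integer seed, which is what lets the new property regenerate indices on demand instead of storing them, and `dtype=np.int32` matches the dtype the property reports. A standalone sketch of that invariant (a reimplementation for illustration, not the module's code):

```python
# Standalone sketch: seeding check_random_state with the same integer yields
# identical bootstrap draws, so indices never need to be stored on the model.
import numpy as np
from sklearn.utils import check_random_state

def generate_sample_indices(random_state, n_samples, n_samples_bootstrap):
    # Mirrors the private helper above in spirit.
    random_instance = check_random_state(random_state)
    return random_instance.randint(0, n_samples, n_samples_bootstrap, dtype=np.int32)

a = generate_sample_indices(42, n_samples=10, n_samples_bootstrap=10)
b = generate_sample_indices(42, n_samples=10, n_samples_bootstrap=10)
assert np.array_equal(a, b)  # same seed -> same in-bag indices
```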
@@ -416,7 +418,7 @@ def fit(self, X, y, sample_weight=None):
                     "is necessary for Poisson regression."
                 )
 
-        self.n_outputs_ = y.shape[1]
+        self._n_samples, self.n_outputs_ = y.shape
 
         y, expanded_class_weight = self._validate_y_class_weight(y)
 
@@ -442,6 +444,8 @@ def fit(self, X, y, sample_weight=None):
         else:
             n_samples_bootstrap = None
 
+        self._n_samples_bootstrap = n_samples_bootstrap
+
         self._validate_estimator()
 
         if not self.bootstrap and self.oob_score:
@@ -679,6 +683,36 @@ def feature_importances_(self):
         all_importances = np.mean(all_importances, axis=0, dtype=np.float64)
         return all_importances / np.sum(all_importances)
 
+    def _get_estimators_indices(self):
+        # Get drawn indices along both sample and feature axes
+        for tree in self.estimators_:
+            if not self.bootstrap:
+                yield np.arange(self._n_samples, dtype=np.int32)
+            else:
+                # tree.random_state is actually an immutable integer seed rather
+                # than a mutable RandomState instance, so it's safe to use it
+                # repeatedly when calling this property.
+                seed = tree.random_state
+                # Operations accessing random_state must be performed identically
+                # to those in `_parallel_build_trees()`
+                yield _generate_sample_indices(
+                    seed, self._n_samples, self._n_samples_bootstrap
+                )
+
+    @property
+    def estimators_samples_(self):
+        """The subset of drawn samples for each base estimator.
+
+        Returns a dynamically generated list of indices identifying
+        the samples used for fitting each member of the ensemble, i.e.,
+        the in-bag samples.
+
+        Note: the list is re-created at each call to the property in order
+        to reduce the object memory footprint by not storing the sampling
+        data. Thus fetching the property may be slower than expected.
+        """
+        return [sample_indices for sample_indices in self._get_estimators_indices()]
+
     def _more_tags(self):
         # Only the criterion is required to determine if the tree supports
         # missing values
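One practical use of the property added above is deriving each tree's out-of-bag samples as the complement of its in-bag indices. A hedged sketch (illustrative data and parameters, not from the diff):

```python
# Illustrative sketch: per-tree out-of-bag (OOB) masks derived from
# estimators_samples_. Assumes a forest fitted with bootstrap=True.
import numpy as np
from sklearn.datasets import make_regression
from sklearn.ensemble import RandomForestRegressor

X, y = make_regression(n_samples=50, random_state=0)
reg = RandomForestRegressor(n_estimators=5, random_state=0).fit(X, y)

for tree, in_bag in zip(reg.estimators_, reg.estimators_samples_):
    oob_mask = np.ones(len(X), dtype=bool)
    oob_mask[in_bag] = False  # True for samples never drawn for this tree
    print(f"tree seed {tree.random_state}: {int(oob_mask.sum())} OOB samples")
```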
@@ -1406,6 +1440,12 @@ class labels (multi-output problem).
         `oob_decision_function_` might contain NaN. This attribute exists
         only when ``oob_score`` is True.
 
+    estimators_samples_ : list of arrays
+        The subset of drawn samples (i.e., the in-bag samples) for each base
+        estimator. Each subset is defined by an array of the indices selected.
+
+        .. versionadded:: 1.4
+
     See Also
     --------
     sklearn.tree.DecisionTreeClassifier : A decision tree classifier.
@@ -1767,6 +1807,12 @@ class RandomForestRegressor(ForestRegressor):
         Prediction computed with out-of-bag estimate on the training set.
         This attribute exists only when ``oob_score`` is True.
 
+    estimators_samples_ : list of arrays
+        The subset of drawn samples (i.e., the in-bag samples) for each base
+        estimator. Each subset is defined by an array of the indices selected.
+
+        .. versionadded:: 1.4
+
     See Also
     --------
     sklearn.tree.DecisionTreeRegressor : A decision tree regressor.
@@ -2149,6 +2195,12 @@ class labels (multi-output problem).
         `oob_decision_function_` might contain NaN. This attribute exists
         only when ``oob_score`` is True.
 
+    estimators_samples_ : list of arrays
+        The subset of drawn samples (i.e., the in-bag samples) for each base
+        estimator. Each subset is defined by an array of the indices selected.
+
+        .. versionadded:: 1.4
+
     See Also
     --------
     ExtraTreesRegressor : An extra-trees regressor with random splits.
@@ -2495,6 +2547,12 @@ class ExtraTreesRegressor(ForestRegressor):
         Prediction computed with out-of-bag estimate on the training set.
         This attribute exists only when ``oob_score`` is True.
 
+    estimators_samples_ : list of arrays
+        The subset of drawn samples (i.e., the in-bag samples) for each base
+        estimator. Each subset is defined by an array of the indices selected.
+
+        .. versionadded:: 1.4
+
     See Also
     --------
     ExtraTreesClassifier : An extra-trees classifier with random splits.
@@ -2742,6 +2800,12 @@ class RandomTreesEmbedding(TransformerMixin, BaseForest):
     one_hot_encoder_ : OneHotEncoder instance
         One-hot encoder used to create the sparse embedding.
 
+    estimators_samples_ : list of arrays
+        The subset of drawn samples (i.e., the in-bag samples) for each base
+        estimator. Each subset is defined by an array of the indices selected.
+
+        .. versionadded:: 1.4
+
     See Also
     --------
     ExtraTreesClassifier : An extra-trees classifier.
62 changes: 60 additions & 2 deletions sklearn/ensemble/tests/test_forest.py
@@ -23,8 +23,8 @@
 from scipy.special import comb
 
 import sklearn
-from sklearn import datasets
-from sklearn.datasets import make_classification
+from sklearn import clone, datasets
+from sklearn.datasets import make_classification, make_hastie_10_2
 from sklearn.decomposition import TruncatedSVD
 from sklearn.dummy import DummyRegressor
 from sklearn.ensemble import (
@@ -46,6 +46,7 @@
 from sklearn.tree._classes import SPARSE_SPLITTERS
 from sklearn.utils._testing import (
     _convert_container,
+    assert_allclose,
     assert_almost_equal,
     assert_array_almost_equal,
     assert_array_equal,
@@ -1686,6 +1687,63 @@ def test_round_samples_to_one_when_samples_too_low(class_weight):
     forest.fit(X, y)
 
 
+@pytest.mark.parametrize("seed", [None, 1])
+@pytest.mark.parametrize("bootstrap", [True, False])
+@pytest.mark.parametrize("ForestClass", FOREST_CLASSIFIERS_REGRESSORS.values())
+def test_estimators_samples(ForestClass, bootstrap, seed):
+    """Estimators_samples_ property should be consistent.
+
+    Tests consistency across fits and whether or not the seed for the random generator
+    is set.
+    """
+    X, y = make_hastie_10_2(n_samples=200, random_state=1)
+
+    if bootstrap:
+        max_samples = 0.5
+    else:
+        max_samples = None
+    est = ForestClass(
+        n_estimators=10,
+        max_samples=max_samples,
+        max_features=0.5,
+        random_state=seed,
+        bootstrap=bootstrap,
+    )
+    est.fit(X, y)
+
+    estimators_samples = est.estimators_samples_.copy()
+
+    # Test repeated calls result in same set of indices
+    assert_array_equal(estimators_samples, est.estimators_samples_)
+    estimators = est.estimators_
+
+    assert isinstance(estimators_samples, list)
+    assert len(estimators_samples) == len(estimators)
+    assert estimators_samples[0].dtype == np.int32
+
+    for i in range(len(estimators)):
+        if bootstrap:
+            assert len(estimators_samples[i]) == len(X) // 2
+
+            # the bootstrap should be a resampling with replacement
+            assert len(np.unique(estimators_samples[i])) < len(estimators_samples[i])
+        else:
+            assert len(set(estimators_samples[i])) == len(X)
+
+    estimator_index = 0
+    estimator_samples = estimators_samples[estimator_index]
+    estimator = estimators[estimator_index]
+
+    X_train = X[estimator_samples]
+    y_train = y[estimator_samples]
+
+    orig_tree_values = estimator.tree_.value
+    estimator = clone(estimator)
+    estimator.fit(X_train, y_train)
+    new_tree_values = estimator.tree_.value
+    assert_allclose(orig_tree_values, new_tree_values)
+
+
 @pytest.mark.parametrize(
     "make_data, Forest",
     [
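A note on the `len(X) // 2` expectation in the test above: with `max_samples=0.5` and 200 samples, each bootstrap draws `round(0.5 * 200) = 100` indices. A simplified sketch of that mapping (modeled on the private helper `_get_n_samples_bootstrap`; validation and edge cases omitted):

```python
# Simplified sketch of how max_samples determines the bootstrap size
# (modeled on sklearn.ensemble._forest._get_n_samples_bootstrap; the real
# helper also validates types and bounds).
def n_samples_bootstrap(n_samples, max_samples):
    if max_samples is None:
        return n_samples              # classic bootstrap: one draw per sample
    if isinstance(max_samples, float):
        return max(round(n_samples * max_samples), 1)
    return max_samples                # an int requests exactly that many draws

assert n_samples_bootstrap(200, 0.5) == 100   # matches len(X) // 2 in the test
assert n_samples_bootstrap(200, None) == 200
```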
