Skip to content

Commit

Permalink
Stratify sampling when split train/test data (#143)
Browse files Browse the repository at this point in the history
* Stratify sampling when split train/test data

---------

Co-authored-by: Haoyin Xu <haoyinxu@gmail.com>
Co-authored-by: Adam Li <adam2392@gmail.com>
Co-authored-by: Sambit Panda <36676569+sampan501@users.noreply.github.com>
  • Loading branch information
4 people authored Oct 19, 2023
1 parent 860d197 commit 359ea75
Show file tree
Hide file tree
Showing 7 changed files with 76 additions and 27 deletions.
1 change: 1 addition & 0 deletions doc/whats_new/_contributors.rst
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,4 @@
.. _SUKI-O : https://github.com/SUKI-O
.. _Ronan Perry : https://rflperry.github.io/
.. _Haoyin Xu : https://github.com/PSSF23
.. _Yuxin Bai : https://github.com/YuxinB
3 changes: 2 additions & 1 deletion doc/whats_new/v0.3.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ Changelog
- |Fix| Fixes a bug in consistency of train/test samples when ``random_state`` is not set in FeatureImportanceForestClassifier and FeatureImportanceForestRegressor, by `Adam Li`_ (:pr:`135`)
- |Fix| Fixes a bug where covariate indices were not shuffled by default when running FeatureImportanceForestClassifier and FeatureImportanceForestRegressor test methods, by `Sambit Panda`_ (:pr:`140`)
- |Enhancement| Add multi-view splitter for axis-aligned decision trees, by `Adam Li`_ (:pr:`129`)
- |Enhancement| Add stratified sampling option to ``FeatureImportance*`` via the ``stratify`` keyword argument, by `Yuxin Bai`_ (:pr:`143`)

Code and Documentation Contributors
-----------------------------------
Expand All @@ -24,4 +25,4 @@ the project since version inception, including:

* `Adam Li`_
* `Sambit Panda`_

* `Yuxin Bai`_
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""
===========================================================
Mutual Information for Gigantic Hypothesis Testing (MIGHT)
===========================================================
=========================================================
Mutual Information for Genuine Hypothesis Testing (MIGHT)
=========================================================
An example using :class:`~sktree.stats.FeatureImportanceForestClassifier` for nonparametric
multivariate hypothesis testing on simulated datasets. Here, we present a simulation
Expand Down Expand Up @@ -49,8 +49,8 @@
# We simulate the two feature sets, and the target variable. We then combine them
# into a single dataset to perform hypothesis testing.

n_samples = 1000
n_features_set = 500
n_samples = 2000
n_features_set = 20
mean = 1.0
sigma = 2.0
beta = 5.0
Expand Down Expand Up @@ -91,7 +91,7 @@
# computed as the proportion of samples in the null distribution that are less than the
# observed test statistic.

n_estimators = 200
n_estimators = 100
max_features = "sqrt"
test_size = 0.2
n_repeats = 1000
Expand All @@ -103,12 +103,12 @@
max_features=max_features,
tree_estimator=DecisionTreeClassifier(),
random_state=seed,
honest_fraction=0.7,
honest_fraction=0.25,
n_jobs=n_jobs,
),
random_state=seed,
test_size=test_size,
permute_per_tree=True,
permute_per_tree=False,
sample_dataset_per_tree=False,
)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""
===============================================================================
Mutual Information for Gigantic Hypothesis Testing (MIGHT) with Imbalanced Data
===============================================================================
==============================================================================
Mutual Information for Genuine Hypothesis Testing (MIGHT) with Imbalanced Data
==============================================================================
Here, we demonstrate how to do hypothesis testing on data whose feature sets
are highly imbalanced in terms of their dimensionalities.
Expand All @@ -17,7 +17,7 @@
For other examples of hypothesis testing, see the following:
- :ref:`sphx_glr_auto_examples_hypothesis_testing_plot_MI_gigantic_hypothesis_testing_forest.py`
- :ref:`sphx_glr_auto_examples_hypothesis_testing_plot_MI_genuine_hypothesis_testing_forest.py`
- :ref:`sphx_glr_auto_examples_hypothesis_testing_plot_might_auc.py`
For more information on the multi-view decision-tree, see
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
numpy>=1.25
scipy>=1.11
scikit-learn>=1.3.1

42 changes: 28 additions & 14 deletions sktree/stats/forestht.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,13 +122,15 @@ def __init__(
test_size=0.2,
permute_per_tree=True,
sample_dataset_per_tree=True,
stratify=False,
):
self.estimator = estimator
self.random_state = random_state
self.verbose = verbose
self.test_size = test_size
self.permute_per_tree = permute_per_tree
self.sample_dataset_per_tree = sample_dataset_per_tree
self.stratify = stratify

self.n_samples_test_ = None
self._n_samples_ = None
Expand Down Expand Up @@ -160,8 +162,9 @@ def reset(self):
self.n_features_in_ = None
self._is_fitted = False
self._seeds = None
self._y = None

def _get_estimators_indices(self, sample_separate=False):
def _get_estimators_indices(self, stratifier=None, sample_separate=False):
indices = np.arange(self._n_samples_, dtype=int)

# Get drawn indices along both sample and feature axes
Expand Down Expand Up @@ -191,7 +194,11 @@ def _get_estimators_indices(self, sample_separate=False):
# Operations accessing random_state must be performed identically
# to those in `_parallel_build_trees()`
indices_train, indices_test = train_test_split(
indices, test_size=self.test_size, shuffle=True, random_state=seed
indices,
test_size=self.test_size,
shuffle=True,
stratify=stratifier,
random_state=seed,
)

yield indices_train, indices_test
Expand All @@ -202,12 +209,13 @@ def _get_estimators_indices(self, sample_separate=False):
else:
self._seeds = self.estimator_.random_state

# TODO: make random_state consistent
indices_train, indices_test = train_test_split(
indices,
test_size=self.test_size,
stratify=stratifier,
random_state=self._seeds,
)

for _ in self.estimator_.estimators_:
yield indices_train, indices_test

Expand All @@ -227,9 +235,12 @@ def train_test_samples_(self):
if self._n_samples_ is None:
raise RuntimeError("The estimator must be fitted before accessing this attribute.")

# Stratifier uses a cached _y attribute if available
stratifier = self._y if is_classifier(self.estimator_) and self.stratify else None

return [
(indices_train, indices_test)
for indices_train, indices_test in self._get_estimators_indices()
for indices_train, indices_test in self._get_estimators_indices(stratifier=stratifier)
]

def _statistic(
Expand Down Expand Up @@ -329,6 +340,8 @@ def statistic(

if self._n_samples_ is None:
self._n_samples_, self.n_features_in_ = X.shape

# Infer type of target y
if self._type_of_target_ is None:
self._type_of_target_ = type_of_target(y)

Expand All @@ -339,9 +352,9 @@ def statistic(
self.permuted_estimator_ = self._get_estimator()
estimator = self.permuted_estimator_

# Infer type of target y
if not hasattr(self, "_type_of_target"):
self._type_of_target_ = type_of_target(y)
# Store a cache of the y variable
if is_classifier(self._get_estimator()):
self._y = y.copy()

# XXX: this can be improved as an extra fit can be avoided, by just doing error-checking
# and then setting the internal meta data structures
Expand Down Expand Up @@ -462,10 +475,10 @@ def test(
observe_posteriors = self.observe_posteriors_
observe_stat = self.observe_stat_

# next permute the data
if covariate_index is None:
covariate_index = np.arange(X.shape[1], dtype=int)

# next permute the data
permute_stat, permute_posteriors, permute_samples = self.statistic(
X,
y,
Expand Down Expand Up @@ -724,9 +737,7 @@ def _statistic(
self.permute_per_tree,
self._type_of_target_,
)
for idx, (indices_train, indices_test) in enumerate(
self._get_estimators_indices(sample_separate=True)
)
for idx, (indices_train, indices_test) in enumerate(self.train_test_samples_)
)
else:
# fitting a forest will only get one unique train/test split
Expand Down Expand Up @@ -825,6 +836,9 @@ class FeatureImportanceForestClassifier(BaseForestHT):
sample_dataset_per_tree : bool, default=False
Whether to sample the dataset per tree or per forest.
stratify : bool, default=True
Whether to stratify the train/test split of the samples by class labels.
Attributes
----------
estimator_ : BaseForest
Expand Down Expand Up @@ -877,6 +891,7 @@ def __init__(
test_size=0.2,
permute_per_tree=True,
sample_dataset_per_tree=True,
stratify=True,
):
super().__init__(
estimator=estimator,
Expand All @@ -885,6 +900,7 @@ def __init__(
test_size=test_size,
permute_per_tree=permute_per_tree,
sample_dataset_per_tree=sample_dataset_per_tree,
stratify=stratify,
)

def _get_estimator(self):
Expand Down Expand Up @@ -945,9 +961,7 @@ def _statistic(
self.permute_per_tree,
self._type_of_target_,
)
for idx, (indices_train, indices_test) in enumerate(
self._get_estimators_indices(sample_separate=True)
)
for idx, (indices_train, indices_test) in enumerate(self.train_test_samples_)
)
else:
# fitting a forest will only get one unique train/test split
Expand Down
32 changes: 32 additions & 0 deletions sktree/stats/tests/test_forestht.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,38 @@ def test_featureimportance_forest_permute_pertree(sample_dataset_per_tree):
est.statistic(iris_X[:n_samples], iris_y[:n_samples], [0, 1.0], metric="mi")


@pytest.mark.parametrize("sample_dataset_per_tree", [True, False])
def test_featureimportance_forest_stratified(sample_dataset_per_tree):
    """Check that the held-out split holds equally many class-0 and class-1 samples.

    The balance of the first ``(train, test)`` pair exposed by
    ``train_test_samples_`` is verified twice: once after ``statistic`` and
    once more after ``test``.

    Parameters
    ----------
    sample_dataset_per_tree : bool
        Forwarded unchanged to ``FeatureImportanceForestClassifier``.
    """
    n_obs = 100
    X, y = iris_X[:n_obs], iris_y[:n_obs]

    forest_ht = FeatureImportanceForestClassifier(
        estimator=RandomForestClassifier(n_estimators=10, random_state=seed),
        permute_per_tree=True,
        test_size=0.7,
        random_state=seed,
        sample_dataset_per_tree=sample_dataset_per_tree,
    )

    def _assert_first_test_split_balanced():
        # Only the first (train, test) pair is inspected.
        _, indices_test = forest_ht.train_test_samples_[0]
        y_test = iris_y[indices_test]
        assert len(y_test[y_test == 0]) == len(y_test[y_test == 1]), (
            f"{len(y_test[y_test==0])} " f"{len(y_test[y_test==1])}"
        )

    # Balance must hold right after computing the observed statistic ...
    forest_ht.statistic(X, y, metric="mi")
    _assert_first_test_split_balanced()

    # ... and still hold after running the permutation test.
    forest_ht.test(X, y, [0, 1], n_repeats=10, metric="mi")
    _assert_first_test_split_balanced()


def test_featureimportance_forest_errors():
permute_per_tree = False
sample_dataset_per_tree = True
Expand Down

0 comments on commit 359ea75

Please sign in to comment.