diff --git a/doc/whats_new/_contributors.rst b/doc/whats_new/_contributors.rst index 3e5ca2110..eb441d66d 100644 --- a/doc/whats_new/_contributors.rst +++ b/doc/whats_new/_contributors.rst @@ -26,3 +26,4 @@ .. _SUKI-O : https://github.com/SUKI-O .. _Ronan Perry : https://rflperry.github.io/ .. _Haoyin Xu : https://github.com/PSSF23 +.. _Yuxin Bai : https://github.com/YuxinB diff --git a/doc/whats_new/v0.3.rst b/doc/whats_new/v0.3.rst index fec97bb01..7b163ef19 100644 --- a/doc/whats_new/v0.3.rst +++ b/doc/whats_new/v0.3.rst @@ -15,6 +15,7 @@ Changelog - |Fix| Fixes a bug in consistency of train/test samples when ``random_state`` is not set in FeatureImportanceForestClassifier and FeatureImportanceForestRegressor, by `Adam Li`_ (:pr:`135`) - |Fix| Fixes a bug where covariate indices were not shuffled by default when running FeatureImportanceForestClassifier and FeatureImportanceForestRegressor test methods, by `Sambit Panda`_ (:pr:`140`) - |Enhancement| Add multi-view splitter for axis-aligned decision trees, by `Adam Li`_ (:pr:`129`) +- |Enhancement| Add stratified sampling option to ``FeatureImportance*`` via the ``stratify`` keyword argument, by `Yuxin Bai`_ (:pr:`143`) Code and Documentation Contributors ----------------------------------- @@ -24,4 +25,4 @@ the project since version inception, including: * `Adam Li`_ * `Sambit Panda`_ - +* `Yuxin Bai`_ diff --git a/examples/hypothesis_testing/plot_MI_gigantic_hypothesis_testing_forest.py b/examples/hypothesis_testing/plot_MI_genuine_hypothesis_testing_forest.py similarity index 94% rename from examples/hypothesis_testing/plot_MI_gigantic_hypothesis_testing_forest.py rename to examples/hypothesis_testing/plot_MI_genuine_hypothesis_testing_forest.py index 423bc63dc..e6831a9e7 100644 --- a/examples/hypothesis_testing/plot_MI_gigantic_hypothesis_testing_forest.py +++ b/examples/hypothesis_testing/plot_MI_genuine_hypothesis_testing_forest.py @@ -1,7 +1,7 @@ """ -=========================================================== -Mutual Information for Gigantic Hypothesis Testing (MIGHT) -=========================================================== +========================================================= +Mutual Information for Genuine Hypothesis Testing (MIGHT) +========================================================= An example using :class:`~sktree.stats.FeatureImportanceForestClassifier` for nonparametric multivariate hypothesis test, on simulated datasets. Here, we present a simulation @@ -49,8 +49,8 @@ # We simulate the two feature sets, and the target variable. We then combine them # into a single dataset to perform hypothesis testing. -n_samples = 1000 -n_features_set = 500 +n_samples = 2000 +n_features_set = 20 mean = 1.0 sigma = 2.0 beta = 5.0 @@ -91,7 +91,7 @@ # computed as the proportion of samples in the null distribution that are less than the # observed test statistic. -n_estimators = 200 +n_estimators = 100 max_features = "sqrt" test_size = 0.2 n_repeats = 1000 @@ -103,12 +103,12 @@ max_features=max_features, tree_estimator=DecisionTreeClassifier(), random_state=seed, - honest_fraction=0.7, + honest_fraction=0.25, n_jobs=n_jobs, ), random_state=seed, test_size=test_size, - permute_per_tree=True, + permute_per_tree=False, sample_dataset_per_tree=False, ) diff --git a/examples/hypothesis_testing/plot_MI_imbalanced_hyppo_testing.py b/examples/hypothesis_testing/plot_MI_imbalanced_hyppo_testing.py index 882f80c3d..c8a5478a4 100644 --- a/examples/hypothesis_testing/plot_MI_imbalanced_hyppo_testing.py +++ b/examples/hypothesis_testing/plot_MI_imbalanced_hyppo_testing.py @@ -1,7 +1,7 @@ """ -=============================================================================== -Mutual Information for Gigantic Hypothesis Testing (MIGHT) with Imbalanced Data -=============================================================================== +============================================================================== +Mutual Information for Genuine Hypothesis Testing (MIGHT) with Imbalanced Data +============================================================================== Here, we demonstrate how to do hypothesis testing on highly imbalanced data in terms of their feature-set dimensionalities. @@ -17,7 +17,7 @@ For other examples of hypothesis testing, see the following: -- :ref:`sphx_glr_auto_examples_hypothesis_testing_plot_MI_gigantic_hypothesis_testing_forest.py` +- :ref:`sphx_glr_auto_examples_hypothesis_testing_plot_MI_genuine_hypothesis_testing_forest.py` - :ref:`sphx_glr_auto_examples_hypothesis_testing_plot_might_auc.py` For more information on the multi-view decision-tree, see diff --git a/requirements.txt b/requirements.txt index 99963814e..978f90fce 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ numpy>=1.25 scipy>=1.11 scikit-learn>=1.3.1 + diff --git a/sktree/stats/forestht.py b/sktree/stats/forestht.py index 4d6dc7b77..56a044c5c 100644 --- a/sktree/stats/forestht.py +++ b/sktree/stats/forestht.py @@ -122,6 +122,7 @@ def __init__( test_size=0.2, permute_per_tree=True, sample_dataset_per_tree=True, + stratify=False, ): self.estimator = estimator self.random_state = random_state @@ -129,6 +130,7 @@ def __init__( self.test_size = test_size self.permute_per_tree = permute_per_tree self.sample_dataset_per_tree = sample_dataset_per_tree + self.stratify = stratify self.n_samples_test_ = None self._n_samples_ = None @@ -160,8 +162,9 @@ def reset(self): self.n_features_in_ = None self._is_fitted = False self._seeds = None + self._y = None - def _get_estimators_indices(self, sample_separate=False): + def _get_estimators_indices(self, stratifier=None, sample_separate=False): indices = np.arange(self._n_samples_, dtype=int) # Get drawn indices along both sample and feature axes @@ -191,7 +194,11 @@ def _get_estimators_indices(self, sample_separate=False): # Operations accessing random_state must be performed identically # to those in `_parallel_build_trees()` indices_train, indices_test = train_test_split( - indices, test_size=self.test_size, shuffle=True, random_state=seed + indices, + test_size=self.test_size, + shuffle=True, + stratify=stratifier, + random_state=seed, ) yield indices_train, indices_test @@ -202,12 +209,13 @@ def _get_estimators_indices(self, sample_separate=False): else: self._seeds = self.estimator_.random_state - # TODO: make random_state consistent indices_train, indices_test = train_test_split( indices, test_size=self.test_size, + stratify=stratifier, random_state=self._seeds, ) + for _ in self.estimator_.estimators_: yield indices_train, indices_test @@ -227,9 +235,12 @@ def train_test_samples_(self): if self._n_samples_ is None: raise RuntimeError("The estimator must be fitted before accessing this attribute.") + # Stratifier uses a cached _y attribute if available + stratifier = self._y if is_classifier(self.estimator_) and self.stratify else None + return [ (indices_train, indices_test) - for indices_train, indices_test in self._get_estimators_indices() + for indices_train, indices_test in self._get_estimators_indices(stratifier=stratifier) ] def _statistic( @@ -329,6 +340,8 @@ def statistic( if self._n_samples_ is None: self._n_samples_, self.n_features_in_ = X.shape + + # Infer type of target y if self._type_of_target_ is None: self._type_of_target_ = type_of_target(y) @@ -339,9 +352,9 @@ def statistic( self.permuted_estimator_ = self._get_estimator() estimator = self.permuted_estimator_ - # Infer type of target y - if not hasattr(self, "_type_of_target"): - self._type_of_target_ = type_of_target(y) + # Store a cache of the y variable + if is_classifier(self._get_estimator()): + self._y = y.copy() # XXX: this can be improved as an extra fit can be avoided, by just doing error-checking # and then setting the internal meta data structures @@ -462,10 +475,10 @@ def test( observe_posteriors = self.observe_posteriors_ observe_stat = self.observe_stat_ - # next permute the data if covariate_index is None: covariate_index = np.arange(X.shape[1], dtype=int) + # next permute the data permute_stat, permute_posteriors, permute_samples = self.statistic( X, y, @@ -724,9 +737,7 @@ def _statistic( self.permute_per_tree, self._type_of_target_, ) - for idx, (indices_train, indices_test) in enumerate( - self._get_estimators_indices(sample_separate=True) - ) + for idx, (indices_train, indices_test) in enumerate(self.train_test_samples_) ) else: # fitting a forest will only get one unique train/test split @@ -825,6 +836,9 @@ class FeatureImportanceForestClassifier(BaseForestHT): sample_dataset_per_tree : bool, default=False Whether to sample the dataset per tree or per forest. + stratify : bool, default=True + Whether to stratify the samples by class labels. + Attributes ---------- estimator_ : BaseForest @@ -877,6 +891,7 @@ def __init__( test_size=0.2, permute_per_tree=True, sample_dataset_per_tree=True, + stratify=True, ): super().__init__( estimator=estimator, @@ -885,6 +900,7 @@ def __init__( test_size=test_size, permute_per_tree=permute_per_tree, sample_dataset_per_tree=sample_dataset_per_tree, + stratify=stratify, ) def _get_estimator(self): @@ -945,9 +961,7 @@ def _statistic( self.permute_per_tree, self._type_of_target_, ) - for idx, (indices_train, indices_test) in enumerate( - self._get_estimators_indices(sample_separate=True) - ) + for idx, (indices_train, indices_test) in enumerate(self.train_test_samples_) ) else: # fitting a forest will only get one unique train/test split diff --git a/sktree/stats/tests/test_forestht.py b/sktree/stats/tests/test_forestht.py index cecf34b8c..e71e5e09b 100644 --- a/sktree/stats/tests/test_forestht.py +++ b/sktree/stats/tests/test_forestht.py @@ -69,6 +69,38 @@ def test_featureimportance_forest_permute_pertree(sample_dataset_per_tree): est.statistic(iris_X[:n_samples], iris_y[:n_samples], [0, 1.0], metric="mi") +@pytest.mark.parametrize("sample_dataset_per_tree", [True, False]) +def test_featureimportance_forest_stratified(sample_dataset_per_tree): + est = FeatureImportanceForestClassifier( + estimator=RandomForestClassifier( + n_estimators=10, + random_state=seed, + ), + permute_per_tree=True, + test_size=0.7, + random_state=seed, + sample_dataset_per_tree=sample_dataset_per_tree, + ) + n_samples = 100 + est.statistic(iris_X[:n_samples], iris_y[:n_samples], metric="mi") + + _, indices_test = est.train_test_samples_[0] + y_test = iris_y[indices_test] + + assert len(y_test[y_test == 0]) == len(y_test[y_test == 1]), ( + f"{len(y_test[y_test==0])} " f"{len(y_test[y_test==1])}" + ) + + est.test(iris_X[:n_samples], iris_y[:n_samples], [0, 1], n_repeats=10, metric="mi") + + _, indices_test = est.train_test_samples_[0] + y_test = iris_y[indices_test] + + assert len(y_test[y_test == 0]) == len(y_test[y_test == 1]), ( + f"{len(y_test[y_test==0])} " f"{len(y_test[y_test==1])}" + ) + + def test_featureimportance_forest_errors(): permute_per_tree = False sample_dataset_per_tree = True