From 739c7bef58586b73f43fae52bff30d611a375f29 Mon Sep 17 00:00:00 2001
From: Adam Li
Date: Tue, 24 Oct 2023 17:35:38 -0400
Subject: [PATCH] Adding ability to turn off train/test split

Signed-off-by: Adam Li
---
 sktree/stats/forestht.py            | 79 ++++++++++++++++++-----------
 sktree/stats/tests/test_forestht.py | 54 ++++++++++++++++++++
 2 files changed, 104 insertions(+), 29 deletions(-)

diff --git a/sktree/stats/forestht.py b/sktree/stats/forestht.py
index c1f9026fd..7f35ff9f4 100644
--- a/sktree/stats/forestht.py
+++ b/sktree/stats/forestht.py
@@ -10,7 +10,6 @@
 from sklearn.utils.multiclass import type_of_target
 from sklearn.utils.validation import _is_fitted, check_X_y
 
-from sktree import HonestForestClassifier
 from sktree._lib.sklearn.ensemble._forest import (
     ForestClassifier,
     ForestRegressor,
@@ -19,6 +18,7 @@
     _get_n_samples_bootstrap,
     _parallel_build_trees,
 )
+from sktree.ensemble._honest_forest import HonestForestClassifier
 from sktree.tree import DecisionTreeClassifier, DecisionTreeRegressor
 from sktree.tree._classes import DTYPE
 
@@ -253,9 +253,12 @@ def train_test_samples_(self):
         if self._n_samples_ is None:
             raise RuntimeError("The estimator must be fitted before accessing this attribute.")
 
-        # we are not train/test splitting, then
+        # if we are not train/test splitting, train on all samples and test on none
         if not self.train_test_split:
-            return [(np.arange(self._n_samples_, dtype=int), np.array([], dtype=int)) for _ in range(len(self.estimator_.estimators_))]
+            return [
+                (np.arange(self._n_samples_, dtype=int), np.array([], dtype=int))
+                for _ in range(len(self.estimator_.estimators_))
+            ]
 
         # Stratifier uses a cached _y attribute if available
         stratifier = self._y if is_classifier(self.estimator_) and self.stratify else None
@@ -310,9 +313,11 @@ def _check_input(self, X: ArrayLike, y: ArrayLike, covariate_index: ArrayLike =
                     f"y must have type {self._type_of_target_}, got {type_of_target(y)}. "
                     f"If running on a new dataset, call the 'reset' method."
                 )
-
+
         if not self.train_test_split and not isinstance(self.estimator, HonestForestClassifier):
-            raise RuntimeError(f'Train test split must occur if not using honest forest classifier.')
+            raise RuntimeError(
+                "A train/test split is required when not using an honest forest classifier."
+            )
 
         return X, y, covariate_index
 
@@ -807,21 +812,24 @@ def _statistic(
         if self.train_test_split:
             # accumulate the predictions across all trees
             all_proba = Parallel(n_jobs=estimator.n_jobs, verbose=self.verbose)(
-                delayed(_parallel_predict_proba)(estimator.estimators_[idx].predict, X, indices_test)
+                delayed(_parallel_predict_proba)(
+                    estimator.estimators_[idx].predict, X, indices_test
+                )
                 for idx, (_, indices_test) in enumerate(self.train_test_samples_)
             )
             for itree, (proba, est_indices) in enumerate(zip(all_proba, self.train_test_samples_)):
                 _, indices_test = est_indices
                 posterior_arr[itree, indices_test, ...] = proba.reshape(-1, estimator.n_outputs_)
         else:
+            all_indices = np.arange(self._n_samples_, dtype=int)
+
             # accumulate the predictions across all trees
             all_proba = Parallel(n_jobs=estimator.n_jobs, verbose=self.verbose)(
-                delayed(_parallel_predict_proba)(estimator.estimators_[idx].predict, X, indices_test)
-                for idx, honest_tree in enumerate(estimator.estimators_)
+                delayed(_parallel_predict_proba)(estimator.estimators_[idx].predict, X, all_indices)
+                for idx in range(len(estimator.estimators_))
             )
-            for itree, (proba, est_indices) in enumerate(zip(all_proba, self.train_test_samples_)):
-                _, indices_test = est_indices
-                posterior_arr[itree, indices_test, ...] = proba.reshape(-1, estimator.n_outputs_)
+            for itree, proba in enumerate(all_proba):
+                posterior_arr[itree, ...] = proba.reshape(-1, estimator.n_outputs_)
 
         # determine if there are any nans in the final posterior array, when
         # averaged over the trees
@@ -1050,29 +1058,37 @@ def _statistic(
             y_train = y_train.ravel()
         estimator.fit(X_train, y_train)
 
-        # construct posterior array for all trees (n_trees, n_samples_test, n_outputs)
-        # for itree, tree in enumerate(estimator.estimators_):
-        #     if predict_posteriors:
-        #         # XXX: currently assumes n_outputs_ == 1
-        #         posterior_arr[itree, indices_test, ...] = tree.predict_proba(X_test).reshape(
-        #             -1, tree.n_classes_
-        #         )
-        #     else:
-        #         posterior_arr[itree, indices_test, ...] = tree.predict(X_test).reshape(
-        #             -1, tree.n_outputs_
-        #         )
-
         # set variables to compute metric
         samples = indices_test
 
         # list of tree outputs. Each tree output is (n_samples, n_outputs), or (n_samples,)
         if predict_posteriors:
-            all_proba = Parallel(n_jobs=estimator.n_jobs, verbose=self.verbose)(
-                delayed(_parallel_predict_proba)(
-                    estimator.estimators_[idx].predict_proba, X, indices_test
+            # all_proba = Parallel(n_jobs=estimator.n_jobs, verbose=self.verbose)(
+            #     delayed(_parallel_predict_proba)(
+            #         estimator.estimators_[idx].predict_proba, X, indices_test
+            #     )
+            #     for idx, (_, indices_test) in enumerate(self.train_test_samples_)
+            # )
+
+            # TODO: probably a more elegant way of doing this
+            if self.train_test_split:
+                # accumulate the predictions across all trees
+                all_proba = Parallel(n_jobs=estimator.n_jobs, verbose=self.verbose)(
+                    delayed(_parallel_predict_proba)(
+                        estimator.estimators_[idx].predict_proba, X, indices_test
+                    )
+                    for idx, (_, indices_test) in enumerate(self.train_test_samples_)
+                )
+            else:
+                all_indices = np.arange(self._n_samples_, dtype=int)
+
+                # accumulate the predictions across all trees
+                all_proba = Parallel(n_jobs=estimator.n_jobs, verbose=self.verbose)(
+                    delayed(_parallel_predict_proba)(
+                        estimator.estimators_[idx].predict_proba, X, all_indices
+                    )
+                    for idx in range(len(estimator.estimators_))
                 )
-                for idx, (_, indices_test) in enumerate(self.train_test_samples_)
-            )
         else:
             all_proba = Parallel(n_jobs=estimator.n_jobs, verbose=self.verbose)(
                 delayed(_parallel_predict_proba)(
@@ -1084,7 +1100,12 @@ def _statistic(
             _, indices_test = est_indices
 
             if predict_posteriors:
-                posterior_arr[itree, indices_test, ...] = proba.reshape(-1, estimator.n_classes_)
+                if self.train_test_split:
+                    posterior_arr[itree, indices_test, ...] = proba.reshape(
+                        -1, estimator.n_classes_
+                    )
+                else:
+                    posterior_arr[itree, ...] = proba.reshape(-1, estimator.n_classes_)
             else:
                 posterior_arr[itree, indices_test, ...] = proba.reshape(-1, estimator.n_outputs_)
 
diff --git a/sktree/stats/tests/test_forestht.py b/sktree/stats/tests/test_forestht.py
index 4343a27af..39d5d9468 100644
--- a/sktree/stats/tests/test_forestht.py
+++ b/sktree/stats/tests/test_forestht.py
@@ -512,3 +512,57 @@ def test_small_dataset_dependent(seed):
 
     stat, pvalue = clf.test(X, y, metric="mi")
     assert pvalue <= 0.05
+
+
+@flaky(max_runs=3)
+def test_no_traintest_split():
+    n_samples = 500
+    n_features = 5
+    rng = np.random.default_rng(seed)
+
+    # first half of the samples is uniform noise; the second half is a scaled copy
+    X = rng.uniform(size=(n_samples // 2, n_features))
+    X2 = X * 2
+    X = np.vstack([X, X2])
+    y = np.vstack(
+        [np.zeros((n_samples // 2, 1)), np.ones((n_samples // 2, 1))]
+    )  # Binary classification
+
+    clf = FeatureImportanceForestClassifier(
+        estimator=HonestForestClassifier(
+            n_estimators=50,
+            max_features=n_features,
+            random_state=seed,
+            n_jobs=1,
+            honest_fraction=0.5,
+        ),
+        test_size=0.2,
+        train_test_split=False,
+        permute_forest_fraction=None,
+        sample_dataset_per_tree=False,
+    )
+    stat, pvalue = clf.test(X, y, covariate_index=[1, 2], metric="mi")
+
+    # with no train/test split, all samples are used for training and none for testing
+    assert_array_equal(clf.train_test_samples_[0][0], np.arange(n_samples))
+    assert_array_equal(clf.train_test_samples_[0][1], np.array([]))
+
+    assert ~np.isnan(pvalue)
+    assert ~np.isnan(stat)
+    assert pvalue <= 0.05, f"{pvalue}"
+
+    stat, pvalue = clf.test(X, y, metric="mi")
+    assert pvalue <= 0.05, f"{pvalue}"
+
+    X = rng.uniform(size=(n_samples, n_features))
+    y = rng.integers(0, 2, size=n_samples)  # Binary classification
+    clf.reset()
+
+    stat, pvalue = clf.test(X, y, metric="mi")
+    assert_almost_equal(stat, 0.0, decimal=1)
+    assert pvalue > 0.05, f"{pvalue}"
+
+    stat, pvalue = clf.test(X, y, covariate_index=[1, 2], metric="mi")
+    assert ~np.isnan(pvalue)
+    assert ~np.isnan(stat)
+    assert pvalue > 0.05, f"{pvalue}"
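
A minimal usage sketch of the new option, not part of the patch itself. It assumes
the public import paths `sktree.HonestForestClassifier` (the import this patch
removes from forestht.py, so it exists at the package root) and
`sktree.stats.FeatureImportanceForestClassifier`; everything else mirrors the new
test above. With train_test_split=False, _check_input raises a RuntimeError unless
the estimator is an HonestForestClassifier, since honest trees already hold out a
fraction of samples (honest_fraction) internally:

    import numpy as np

    from sktree import HonestForestClassifier
    from sktree.stats import FeatureImportanceForestClassifier  # assumed import path

    rng = np.random.default_rng(12345)
    X = rng.uniform(size=(500, 5))
    y = rng.integers(0, 2, size=500)

    clf = FeatureImportanceForestClassifier(
        estimator=HonestForestClassifier(
            n_estimators=50,
            honest_fraction=0.5,  # fraction of samples held out inside each honest tree
            random_state=12345,
        ),
        # the new flag: fit and evaluate on all samples instead of an external split
        train_test_split=False,
    )

    # permutation test of the covariates at columns 1 and 2, scored by mutual information
    stat, pvalue = clf.test(X, y, covariate_index=[1, 2], metric="mi")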