From 2887909692f3ef7a55bbd6f6b1ba113f110ed9d4 Mon Sep 17 00:00:00 2001 From: Adam Li Date: Tue, 24 Oct 2023 13:16:17 -0400 Subject: [PATCH] WIP Signed-off-by: Adam Li --- sktree/stats/forestht.py | 108 +++++++++++++--------------- sktree/stats/tests/test_forestht.py | 17 ++--- 2 files changed, 55 insertions(+), 70 deletions(-) diff --git a/sktree/stats/forestht.py b/sktree/stats/forestht.py index 847557ac0..c1f9026fd 100644 --- a/sktree/stats/forestht.py +++ b/sktree/stats/forestht.py @@ -10,6 +10,7 @@ from sklearn.utils.multiclass import type_of_target from sklearn.utils.validation import _is_fitted, check_X_y +from sktree import HonestForestClassifier from sktree._lib.sklearn.ensemble._forest import ( ForestClassifier, ForestRegressor, @@ -117,6 +118,7 @@ def __init__( stratify=False, sample_dataset_per_tree=False, permute_forest_fraction=None, + train_test_split=True, ): self.estimator = estimator self.random_state = random_state @@ -124,6 +126,7 @@ def __init__( self.test_size = test_size self.stratify = stratify + self.train_test_split = train_test_split # XXX: possibly removing these parameters self.sample_dataset_per_tree = sample_dataset_per_tree self.permute_forest_fraction = permute_forest_fraction @@ -250,6 +253,10 @@ def train_test_samples_(self): if self._n_samples_ is None: raise RuntimeError("The estimator must be fitted before accessing this attribute.") + # we are not train/test splitting, then + if not self.train_test_split: + return [(np.arange(self._n_samples_, dtype=int), np.array([], dtype=int)) for _ in range(len(self.estimator_.estimators_))] + # Stratifier uses a cached _y attribute if available stratifier = self._y if is_classifier(self.estimator_) and self.stratify else None @@ -303,6 +310,9 @@ def _check_input(self, X: ArrayLike, y: ArrayLike, covariate_index: ArrayLike = f"y must have type {self._type_of_target_}, got {type_of_target(y)}. " f"If running on a new dataset, call the 'reset' method." ) + + if not self.train_test_split and not isinstance(self.estimator, HonestForestClassifier): + raise RuntimeError(f'Train test split must occur if not using honest forest classifier.') return X, y, covariate_index @@ -592,12 +602,16 @@ class FeatureImportanceForestRegressor(BaseForestHT): test_size : float, default=0.2 Proportion of samples per tree to use for the test set. - permute_per_tree : bool, default=True - Whether to permute the covariate index per tree or per forest. - sample_dataset_per_tree : bool, default=False Whether to sample the dataset per tree or per forest. + permute_forest_fraction : float, default=None + The fraction of trees to permute the covariate index for. If None, then + just one permutation is performed. + + train_test_split : bool, default=True + Whether to split the dataset before passing to the forest. + Attributes ---------- estimator_ : BaseForest @@ -653,6 +667,7 @@ def __init__( # permute_per_tree=False, sample_dataset_per_tree=False, permute_forest_fraction=None, + train_test_split=True, ): super().__init__( estimator=estimator, @@ -662,6 +677,7 @@ def __init__( # permute_per_tree=permute_per_tree, sample_dataset_per_tree=sample_dataset_per_tree, permute_forest_fraction=permute_forest_fraction, + train_test_split=train_test_split, ) def _get_estimator(self): @@ -740,25 +756,6 @@ def _statistic( # both sampling dataset per tree or permuting per tree requires us to bypass the # sklearn API to fit each tree individually if self.sample_dataset_per_tree or self.permute_forest_fraction: - # Parallel(n_jobs=estimator.n_jobs, verbose=self.verbose, prefer="threads")( - # delayed(_parallel_build_trees_and_compute_posteriors)( - # estimator, - # idx, - # indices_train, - # indices_test, - # X, - # y, - # covariate_index, - # posterior_arr, - # False, - # self.permute_per_tree, - # self._type_of_target_, - # ) - # for idx, (indices_train, indices_test) in enumerate( - # self._get_estimators_indices(sample_separate=True) - # ) - # ) - if self.permute_forest_fraction and covariate_index is not None: random_states = [tree.random_state for tree in estimator.estimators_] else: @@ -802,24 +799,29 @@ def _statistic( y_train = y_train.ravel() estimator.fit(X_train, y_train) - # construct posterior array for all trees (n_trees, n_samples_test, n_outputs) - # for itree, tree in enumerate(estimator.estimators_): - # posterior_arr[itree, indices_test, ...] = tree.predict(X_test).reshape( - # -1, tree.n_outputs_ - # ) - # set variables to compute metric samples = indices_test y_true_final = y_test - # accumulate the predictions across all trees - all_proba = Parallel(n_jobs=estimator.n_jobs, verbose=self.verbose)( - delayed(_parallel_predict_proba)(estimator.estimators_[idx].predict, X, indices_test) - for idx, (_, indices_test) in enumerate(self.train_test_samples_) - ) - for itree, (proba, est_indices) in enumerate(zip(all_proba, self.train_test_samples_)): - _, indices_test = est_indices - posterior_arr[itree, indices_test, ...] = proba.reshape(-1, estimator.n_outputs_) + # TODO: probably a more elegant way of doing this + if self.train_test_split: + # accumulate the predictions across all trees + all_proba = Parallel(n_jobs=estimator.n_jobs, verbose=self.verbose)( + delayed(_parallel_predict_proba)(estimator.estimators_[idx].predict, X, indices_test) + for idx, (_, indices_test) in enumerate(self.train_test_samples_) + ) + for itree, (proba, est_indices) in enumerate(zip(all_proba, self.train_test_samples_)): + _, indices_test = est_indices + posterior_arr[itree, indices_test, ...] = proba.reshape(-1, estimator.n_outputs_) + else: + # accumulate the predictions across all trees + all_proba = Parallel(n_jobs=estimator.n_jobs, verbose=self.verbose)( + delayed(_parallel_predict_proba)(estimator.estimators_[idx].predict, X, indices_test) + for idx, honest_tree in enumerate(estimator.estimators_) + ) + for itree, (proba, est_indices) in enumerate(zip(all_proba, self.train_test_samples_)): + _, indices_test = est_indices + posterior_arr[itree, indices_test, ...] = proba.reshape(-1, estimator.n_outputs_) # determine if there are any nans in the final posterior array, when # averaged over the trees @@ -880,14 +882,18 @@ class FeatureImportanceForestClassifier(BaseForestHT): test_size : float, default=0.2 Proportion of samples per tree to use for the test set. - permute_per_tree : bool, default=True - Whether to permute the covariate index per tree or per forest. + stratify : bool, default=True + Whether to stratify the samples by class labels. sample_dataset_per_tree : bool, default=False Whether to sample the dataset per tree or per forest. - stratify : bool, default=True - Whether to stratify the samples by class labels. + permute_forest_fraction : float, default=None + The fraction of trees to permute the covariate index for. If None, then + just one permutation is performed. + + train_test_split : bool, default=True + Whether to split the data into train/test before passing to the forest. Attributes ---------- @@ -940,9 +946,10 @@ def __init__( verbose=0, test_size=0.2, # permute_per_tree=False, - sample_dataset_per_tree=False, stratify=True, + sample_dataset_per_tree=False, permute_forest_fraction=None, + train_test_split=True, ): super().__init__( estimator=estimator, @@ -952,6 +959,7 @@ def __init__( # permute_per_tree=permute_per_tree, sample_dataset_per_tree=sample_dataset_per_tree, stratify=stratify, + train_test_split=train_test_split, permute_forest_fraction=permute_forest_fraction, ) @@ -999,24 +1007,6 @@ def _statistic( # both sampling dataset per tree or permuting per tree requires us to bypass the # sklearn API to fit each tree individually if self.sample_dataset_per_tree or self.permute_forest_fraction: - # Parallel(n_jobs=estimator.n_jobs, verbose=self.verbose, prefer="threads")( - # delayed(_parallel_build_trees_and_compute_posteriors)( - # estimator, - # idx, - # indices_train, - # indices_test, - # X, - # y, - # covariate_index, - # posterior_arr, - # predict_posteriors, - # self.permute_per_tree, - # self._type_of_target_, - # ) - # for idx, (indices_train, indices_test) in enumerate( - # self._get_estimators_indices(sample_separate=True) - # ) - # ) if self.permute_forest_fraction and covariate_index is not None: random_states = [tree.random_state for tree in estimator.estimators_] else: diff --git a/sktree/stats/tests/test_forestht.py b/sktree/stats/tests/test_forestht.py index 6c6668563..4343a27af 100644 --- a/sktree/stats/tests/test_forestht.py +++ b/sktree/stats/tests/test_forestht.py @@ -328,14 +328,8 @@ def test_pickle(tmpdir): @pytest.mark.parametrize( "permute_forest_fraction", - [ - None, - # 0.5 - ], - ids=[ - "no_permute" - # "permute_forest_fraction", - ], + [None, 0.5], + ids=["no_permute", "permute_forest_fraction"], ) @pytest.mark.parametrize( "sample_dataset_per_tree", [True, False], ids=["sample_dataset_per_tree", "no_sample_dataset"] @@ -363,11 +357,12 @@ def test_sample_size_consistency_of_estimator_indices_( _, posteriors, samples = clf.statistic( X, y, covariate_index=None, return_posteriors=True, metric="mi" ) - print(clf._seeds) - if sample_dataset_per_tree: + + if sample_dataset_per_tree or permute_forest_fraction is not None: # check the non-nans non_nan_idx = _non_nan_samples(posteriors) - assert clf.n_samples_test_ == n_samples, f"{clf.n_samples_test_} != {n_samples}" + if sample_dataset_per_tree: + assert clf.n_samples_test_ == n_samples, f"{clf.n_samples_test_} != {n_samples}" sorted_sample_idx = sorted(np.unique(samples)) sorted_est_samples_idx = sorted(