Commit 2887909: WIP

Signed-off-by: Adam Li <adam2392@gmail.com>
adam2392 committed Oct 24, 2023
1 parent e894708 commit 2887909
Showing 2 changed files with 55 additions and 70 deletions.
108 changes: 49 additions & 59 deletions sktree/stats/forestht.py
@@ -10,6 +10,7 @@
from sklearn.utils.multiclass import type_of_target
from sklearn.utils.validation import _is_fitted, check_X_y

from sktree import HonestForestClassifier
from sktree._lib.sklearn.ensemble._forest import (
ForestClassifier,
ForestRegressor,
@@ -117,13 +118,15 @@ def __init__(
stratify=False,
sample_dataset_per_tree=False,
permute_forest_fraction=None,
train_test_split=True,
):
self.estimator = estimator
self.random_state = random_state
self.verbose = verbose
self.test_size = test_size
self.stratify = stratify

self.train_test_split = train_test_split
# XXX: possibly removing these parameters
self.sample_dataset_per_tree = sample_dataset_per_tree
self.permute_forest_fraction = permute_forest_fraction
@@ -250,6 +253,10 @@ def train_test_samples_(self):
if self._n_samples_ is None:
raise RuntimeError("The estimator must be fitted before accessing this attribute.")

# if we are not train/test splitting, each tree trains on all
# samples and has an empty test set
if not self.train_test_split:
    return [
        (np.arange(self._n_samples_, dtype=int), np.array([], dtype=int))
        for _ in range(len(self.estimator_.estimators_))
    ]

# Stratifier uses a cached _y attribute if available
stratifier = self._y if is_classifier(self.estimator_) and self.stratify else None
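A minimal sketch of the value the no-split branch above returns (the sample and tree counts here are illustrative, not from the diff): one (train, test) index pair per tree, with every sample in train and an empty test set.

import numpy as np

n_samples, n_trees = 5, 3
# one (train_indices, test_indices) pair per tree, mirroring the no-split branch
splits = [
    (np.arange(n_samples, dtype=int), np.array([], dtype=int))
    for _ in range(n_trees)
]
assert all(train.size == n_samples and test.size == 0 for train, test in splits)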

@@ -303,6 +310,9 @@ def _check_input(self, X: ArrayLike, y: ArrayLike, covariate_index: ArrayLike =
f"y must have type {self._type_of_target_}, got {type_of_target(y)}. "
f"If running on a new dataset, call the 'reset' method."
)

if not self.train_test_split and not isinstance(self.estimator, HonestForestClassifier):
    raise RuntimeError(
        "Train/test splitting is required unless using an honest forest classifier."
    )

return X, y, covariate_index
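A hedged usage sketch of the new flag (the constructor arguments besides estimator and train_test_split are illustrative, and the import paths assume the package layout used elsewhere in this repo):

from sktree import HonestForestClassifier
from sktree.stats import FeatureImportanceForestClassifier

# honest forests reserve held-out samples internally, so the external
# train/test split can be skipped
clf = FeatureImportanceForestClassifier(
    estimator=HonestForestClassifier(n_estimators=50, random_state=0),
    train_test_split=False,
)

# per the check above, combining train_test_split=False with any
# non-honest estimator raises a RuntimeError in _check_input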

@@ -592,12 +602,16 @@ class FeatureImportanceForestRegressor(BaseForestHT):
test_size : float, default=0.2
Proportion of samples per tree to use for the test set.
permute_per_tree : bool, default=True
Whether to permute the covariate index per tree or per forest.
sample_dataset_per_tree : bool, default=False
Whether to sample the dataset per tree or per forest.
permute_forest_fraction : float, default=None
The fraction of trees to permute the covariate index for. If None, then
just one permutation is performed.
train_test_split : bool, default=True
Whether to split the dataset before passing to the forest.

Attributes
----------
estimator_ : BaseForest
@@ -653,6 +667,7 @@ def __init__(
# permute_per_tree=False,
sample_dataset_per_tree=False,
permute_forest_fraction=None,
train_test_split=True,
):
super().__init__(
estimator=estimator,
@@ -662,6 +677,7 @@
# permute_per_tree=permute_per_tree,
sample_dataset_per_tree=sample_dataset_per_tree,
permute_forest_fraction=permute_forest_fraction,
train_test_split=train_test_split,
)

def _get_estimator(self):
@@ -740,25 +756,6 @@ def _statistic(
# both sampling the dataset per tree and permuting per tree require us to
# bypass the sklearn API and fit each tree individually
if self.sample_dataset_per_tree or self.permute_forest_fraction:
if self.permute_forest_fraction and covariate_index is not None:
random_states = [tree.random_state for tree in estimator.estimators_]
else:
@@ -802,24 +799,29 @@ def _statistic(
y_train = y_train.ravel()
estimator.fit(X_train, y_train)

# set variables to compute metric
samples = indices_test
y_true_final = y_test

# TODO: probably a more elegant way of doing this
if self.train_test_split:
    # accumulate the predictions across all trees
    all_proba = Parallel(n_jobs=estimator.n_jobs, verbose=self.verbose)(
        delayed(_parallel_predict_proba)(
            estimator.estimators_[idx].predict, X, indices_test
        )
        for idx, (_, indices_test) in enumerate(self.train_test_samples_)
    )
    for itree, (proba, est_indices) in enumerate(zip(all_proba, self.train_test_samples_)):
        _, indices_test = est_indices
        posterior_arr[itree, indices_test, ...] = proba.reshape(-1, estimator.n_outputs_)
else:
    # no external train/test split was performed, so each honest tree
    # predicts on every sample
    all_indices = np.arange(self._n_samples_, dtype=int)
    all_proba = Parallel(n_jobs=estimator.n_jobs, verbose=self.verbose)(
        delayed(_parallel_predict_proba)(
            estimator.estimators_[idx].predict, X, all_indices
        )
        for idx in range(len(estimator.estimators_))
    )
    for itree, proba in enumerate(all_proba):
        posterior_arr[itree, all_indices, ...] = proba.reshape(-1, estimator.n_outputs_)
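For context, a minimal numpy sketch (values are illustrative) of why posterior_arr is NaN-initialized and only filled at each tree's test indices: averaging over trees with nanmean skips trees that never scored a given sample, which is what the non-NaN checks below rely on.

import numpy as np

n_trees, n_samples, n_outputs = 2, 4, 1
posterior_arr = np.full((n_trees, n_samples, n_outputs), np.nan)

# tree 0 scored samples {0, 1}; tree 1 scored samples {2, 3}
posterior_arr[0, [0, 1], 0] = [0.2, 0.8]
posterior_arr[1, [2, 3], 0] = [0.4, 0.6]

# averaging over trees while ignoring NaNs keeps every sample usable
avg_posterior = np.nanmean(posterior_arr, axis=0)
assert not np.isnan(avg_posterior).any()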

# determine if there are any nans in the final posterior array, when
# averaged over the trees
@@ -880,14 +882,18 @@ class FeatureImportanceForestClassifier(BaseForestHT):
test_size : float, default=0.2
Proportion of samples per tree to use for the test set.
permute_per_tree : bool, default=True
Whether to permute the covariate index per tree or per forest.
stratify : bool, default=True
Whether to stratify the samples by class labels.
sample_dataset_per_tree : bool, default=False
Whether to sample the dataset per tree or per forest.
permute_forest_fraction : float, default=None
The fraction of trees to permute the covariate index for. If None, then
just one permutation is performed.
train_test_split : bool, default=True
Whether to split the data into train/test before passing to the forest.

Attributes
----------
@@ -940,9 +946,10 @@ def __init__(
verbose=0,
test_size=0.2,
# permute_per_tree=False,
stratify=True,
sample_dataset_per_tree=False,
permute_forest_fraction=None,
train_test_split=True,
):
super().__init__(
estimator=estimator,
Expand All @@ -952,6 +959,7 @@ def __init__(
# permute_per_tree=permute_per_tree,
sample_dataset_per_tree=sample_dataset_per_tree,
stratify=stratify,
train_test_split=train_test_split,
permute_forest_fraction=permute_forest_fraction,
)

@@ -999,24 +1007,6 @@ def _statistic(
# both sampling the dataset per tree and permuting per tree require us to
# bypass the sklearn API and fit each tree individually
if self.sample_dataset_per_tree or self.permute_forest_fraction:
if self.permute_forest_fraction and covariate_index is not None:
random_states = [tree.random_state for tree in estimator.estimators_]
else:
17 changes: 6 additions & 11 deletions sktree/stats/tests/test_forestht.py
@@ -328,14 +328,8 @@ def test_pickle(tmpdir):

@pytest.mark.parametrize(
"permute_forest_fraction",
[None, 0.5],
ids=["no_permute", "permute_forest_fraction"],
)
@pytest.mark.parametrize(
"sample_dataset_per_tree", [True, False], ids=["sample_dataset_per_tree", "no_sample_dataset"]
@@ -363,11 +357,12 @@ def test_sample_size_consistency_of_estimator_indices_(
_, posteriors, samples = clf.statistic(
X, y, covariate_index=None, return_posteriors=True, metric="mi"
)

if sample_dataset_per_tree or permute_forest_fraction is not None:
    # check the non-nan samples
    non_nan_idx = _non_nan_samples(posteriors)
    if sample_dataset_per_tree:
        assert clf.n_samples_test_ == n_samples, f"{clf.n_samples_test_} != {n_samples}"

sorted_sample_idx = sorted(np.unique(samples))
sorted_est_samples_idx = sorted(