Commit
Adding ability to turn off train/test split
Signed-off-by: Adam Li <adam2392@gmail.com>
adam2392 committed Oct 24, 2023
1 parent 2887909 commit 739c7be
Showing 2 changed files with 104 additions and 29 deletions.
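
In effect, the change lets every honest tree score every sample, rather than holding out a per-tree test set (honest trees already reserve an internal honest fraction, which is why the guard below requires an honest forest). A minimal usage sketch, mirroring the new test added at the bottom of this diff; the synthetic data and hyperparameters here are illustrative only:

    import numpy as np
    from sktree import HonestForestClassifier
    from sktree.stats import FeatureImportanceForestClassifier

    rng = np.random.default_rng(0)
    X = rng.uniform(size=(100, 5))
    y = rng.integers(0, 2, size=100)  # binary labels

    clf = FeatureImportanceForestClassifier(
        estimator=HonestForestClassifier(n_estimators=50, honest_fraction=0.5),
        train_test_split=False,  # the new option: no held-out test split
    )
    stat, pvalue = clf.test(X, y, metric="mi")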
sktree/stats/forestht.py: 79 changes (50 additions, 29 deletions)
@@ -10,7 +10,6 @@
 from sklearn.utils.multiclass import type_of_target
 from sklearn.utils.validation import _is_fitted, check_X_y

-from sktree import HonestForestClassifier
 from sktree._lib.sklearn.ensemble._forest import (
     ForestClassifier,
     ForestRegressor,
@@ -19,6 +18,7 @@
     _get_n_samples_bootstrap,
     _parallel_build_trees,
 )
+from sktree.ensemble._honest_forest import HonestForestClassifier
 from sktree.tree import DecisionTreeClassifier, DecisionTreeRegressor
 from sktree.tree._classes import DTYPE

@@ -253,9 +253,12 @@ def train_test_samples_(self):
         if self._n_samples_ is None:
             raise RuntimeError("The estimator must be fitted before accessing this attribute.")

         # if we are not train/test splitting, then return all samples as the train set
         if not self.train_test_split:
-            return [(np.arange(self._n_samples_, dtype=int), np.array([], dtype=int)) for _ in range(len(self.estimator_.estimators_))]
+            return [
+                (np.arange(self._n_samples_, dtype=int), np.array([], dtype=int))
+                for _ in range(len(self.estimator_.estimators_))
+            ]

         # Stratifier uses a cached _y attribute if available
         stratifier = self._y if is_classifier(self.estimator_) and self.stratify else None
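
With the split disabled, each tree's (train, test) pair is simply (all sample indices, an empty array). A standalone sketch of the structure the property now returns, using toy sizes rather than the estimator itself:

    import numpy as np

    n_samples, n_trees = 6, 3
    splits = [
        (np.arange(n_samples, dtype=int), np.array([], dtype=int))
        for _ in range(n_trees)
    ]
    train, test = splits[0]
    assert train.shape == (n_samples,) and test.size == 0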
@@ -310,9 +313,11 @@ def _check_input(self, X: ArrayLike, y: ArrayLike, covariate_index: ArrayLike =
                 f"y must have type {self._type_of_target_}, got {type_of_target(y)}. "
                 f"If running on a new dataset, call the 'reset' method."
             )

         if not self.train_test_split and not isinstance(self.estimator, HonestForestClassifier):
-            raise RuntimeError(f'Train test split must occur if not using honest forest classifier.')
+            raise RuntimeError(
+                "Train test split must occur if not using honest forest classifier."
+            )

         return X, y, covariate_index
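
The guard means only an honest forest may skip the split; any other estimator would be evaluated on its own training samples. A sketch of the expected failure mode (pytest-style; whether the tester accepts a plain sklearn RandomForestClassifier at all is an assumption here, it simply stands in for any non-honest estimator):

    import pytest
    from sklearn.ensemble import RandomForestClassifier
    from sktree.stats import FeatureImportanceForestClassifier

    clf = FeatureImportanceForestClassifier(
        estimator=RandomForestClassifier(n_estimators=10),
        train_test_split=False,
    )
    with pytest.raises(RuntimeError, match="Train test split must occur"):
        clf.test(X, y, metric="mi")  # X, y as in the earlier sketch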

@@ -807,21 +812,24 @@ def _statistic(
         if self.train_test_split:
             # accumulate the predictions across all trees
             all_proba = Parallel(n_jobs=estimator.n_jobs, verbose=self.verbose)(
-                delayed(_parallel_predict_proba)(estimator.estimators_[idx].predict, X, indices_test)
+                delayed(_parallel_predict_proba)(
+                    estimator.estimators_[idx].predict, X, indices_test
+                )
                 for idx, (_, indices_test) in enumerate(self.train_test_samples_)
             )
             for itree, (proba, est_indices) in enumerate(zip(all_proba, self.train_test_samples_)):
                 _, indices_test = est_indices
                 posterior_arr[itree, indices_test, ...] = proba.reshape(-1, estimator.n_outputs_)
         else:
+            all_indices = np.arange(self._n_samples_, dtype=int)
+
             # accumulate the predictions across all trees
             all_proba = Parallel(n_jobs=estimator.n_jobs, verbose=self.verbose)(
-                delayed(_parallel_predict_proba)(estimator.estimators_[idx].predict, X, indices_test)
-                for idx, honest_tree in enumerate(estimator.estimators_)
+                delayed(_parallel_predict_proba)(estimator.estimators_[idx].predict, X, all_indices)
+                for idx in range(len(estimator.estimators_))
             )
-            for itree, (proba, est_indices) in enumerate(zip(all_proba, self.train_test_samples_)):
-                _, indices_test = est_indices
-                posterior_arr[itree, indices_test, ...] = proba.reshape(-1, estimator.n_outputs_)
+            for itree, proba in enumerate(all_proba):
+                posterior_arr[itree, ...] = proba.reshape(-1, estimator.n_outputs_)

         # determine if there are any nans in the final posterior array, when
         # averaged over the trees
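
The indexing above is the crux of the feature: with a split, each tree fills only its held-out rows of posterior_arr and the rest stay NaN, which the NaN check mentioned above then has to handle; without a split, each tree fills every row. A standalone numpy illustration of the two fill patterns (toy shapes; whether the estimator averages with nanmean exactly as below is not shown in this diff):

    import numpy as np

    n_trees, n_samples, n_outputs = 2, 4, 1
    posterior_arr = np.full((n_trees, n_samples, n_outputs), np.nan)

    # train/test split on: tree 0 scored only its test rows, e.g. rows [0, 1]
    posterior_arr[0, [0, 1], ...] = np.array([[0.2], [0.8]])

    # train/test split off: tree 1 scored every row
    posterior_arr[1, ...] = np.full((n_samples, n_outputs), 0.5)

    # averaging across trees skips rows a given tree never scored
    mean_posterior = np.nanmean(posterior_arr, axis=0)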
@@ -1050,29 +1058,37 @@ def _statistic(
             y_train = y_train.ravel()
         estimator.fit(X_train, y_train)

-        # construct posterior array for all trees (n_trees, n_samples_test, n_outputs)
-        # for itree, tree in enumerate(estimator.estimators_):
-        #     if predict_posteriors:
-        #         # XXX: currently assumes n_outputs_ == 1
-        #         posterior_arr[itree, indices_test, ...] = tree.predict_proba(X_test).reshape(
-        #             -1, tree.n_classes_
-        #         )
-        #     else:
-        #         posterior_arr[itree, indices_test, ...] = tree.predict(X_test).reshape(
-        #             -1, tree.n_outputs_
-        #         )

         # set variables to compute metric
         samples = indices_test

         # list of tree outputs. Each tree output is (n_samples, n_outputs), or (n_samples,)
         if predict_posteriors:
-            all_proba = Parallel(n_jobs=estimator.n_jobs, verbose=self.verbose)(
-                delayed(_parallel_predict_proba)(
-                    estimator.estimators_[idx].predict_proba, X, indices_test
-                )
-                for idx, (_, indices_test) in enumerate(self.train_test_samples_)
-            )
+            # all_proba = Parallel(n_jobs=estimator.n_jobs, verbose=self.verbose)(
+            #     delayed(_parallel_predict_proba)(
+            #         estimator.estimators_[idx].predict_proba, X, indices_test
+            #     )
+            #     for idx, (_, indices_test) in enumerate(self.train_test_samples_)
+            # )
+
+            # TODO: probably a more elegant way of doing this
+            if self.train_test_split:
+                # accumulate the predictions across all trees
+                all_proba = Parallel(n_jobs=estimator.n_jobs, verbose=self.verbose)(
+                    delayed(_parallel_predict_proba)(
+                        estimator.estimators_[idx].predict_proba, X, indices_test
+                    )
+                    for idx, (_, indices_test) in enumerate(self.train_test_samples_)
+                )
+            else:
+                all_indices = np.arange(self._n_samples_, dtype=int)
+
+                # accumulate the predictions across all trees
+                all_proba = Parallel(n_jobs=estimator.n_jobs, verbose=self.verbose)(
+                    delayed(_parallel_predict_proba)(
+                        estimator.estimators_[idx].predict_proba, X, all_indices
+                    )
+                    for idx in range(len(estimator.estimators_))
+                )
         else:
             all_proba = Parallel(n_jobs=estimator.n_jobs, verbose=self.verbose)(
                 delayed(_parallel_predict_proba)(
@@ -1084,7 +1100,12 @@ def _statistic(
             _, indices_test = est_indices

             if predict_posteriors:
-                posterior_arr[itree, indices_test, ...] = proba.reshape(-1, estimator.n_classes_)
+                if self.train_test_split:
+                    posterior_arr[itree, indices_test, ...] = proba.reshape(
+                        -1, estimator.n_classes_
+                    )
+                else:
+                    posterior_arr[itree, ...] = proba.reshape(-1, estimator.n_classes_)
             else:
                 posterior_arr[itree, indices_test, ...] = proba.reshape(-1, estimator.n_outputs_)
sktree/stats/tests/test_forestht.py: 54 changes (54 additions, 0 deletions)
@@ -512,3 +512,57 @@ def test_small_dataset_dependent(seed):

     stat, pvalue = clf.test(X, y, metric="mi")
     assert pvalue <= 0.05
+
+
+@flaky(max_runs=3)
+def test_no_traintest_split():
+    n_samples = 500
+    n_features = 5
+    rng = np.random.default_rng(seed)
+
+    X = rng.uniform(size=(n_samples // 2, n_features))
+    X2 = X * 2
+    X = np.vstack([X, X2])
+    y = np.vstack(
+        [np.zeros((n_samples // 2, 1)), np.ones((n_samples // 2, 1))]
+    )  # Binary classification
+
+    clf = FeatureImportanceForestClassifier(
+        estimator=HonestForestClassifier(
+            n_estimators=50,
+            max_features=n_features,
+            random_state=seed,
+            n_jobs=1,
+            honest_fraction=0.5,
+        ),
+        test_size=0.2,
+        train_test_split=False,
+        permute_forest_fraction=None,
+        sample_dataset_per_tree=False,
+    )
+    stat, pvalue = clf.test(X, y, covariate_index=[1, 2], metric="mi")
+
+    # since there is no train-test split, the training set is all the data
+    # and the test set is empty
+    assert_array_equal(clf.train_test_samples_[0][0], np.arange(n_samples))
+    assert_array_equal(clf.train_test_samples_[0][1], np.array([]))
+
+    assert ~np.isnan(pvalue)
+    assert ~np.isnan(stat)
+    assert pvalue <= 0.05, f"{pvalue}"
+
+    stat, pvalue = clf.test(X, y, metric="mi")
+    assert pvalue <= 0.05, f"{pvalue}"
+
+    X = rng.uniform(size=(n_samples, n_features))
+    y = rng.integers(0, 2, size=n_samples)  # Binary classification
+    clf.reset()
+
+    stat, pvalue = clf.test(X, y, metric="mi")
+    assert_almost_equal(stat, 0.0, decimal=1)
+    assert pvalue > 0.05, f"{pvalue}"
+
+    stat, pvalue = clf.test(X, y, covariate_index=[1, 2], metric="mi")
+    assert ~np.isnan(pvalue)
+    assert ~np.isnan(stat)
+    assert pvalue > 0.05, f"{pvalue}"
