Skip to content

Commit

Permalink
UPdate and address permute forest fraction
Browse files Browse the repository at this point in the history
Signed-off-by: Adam Li <adam2392@gmail.com>
  • Loading branch information
adam2392 committed Oct 24, 2023
1 parent 0f27d01 commit e894708
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 67 deletions.
62 changes: 38 additions & 24 deletions sktree/stats/forestht.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,19 +114,19 @@ def __init__(
random_state=None,
verbose=0,
test_size=0.2,
permute_per_tree=False,
sample_dataset_per_tree=False,
stratify=False,
sample_dataset_per_tree=False,
permute_forest_fraction=None,
):
self.estimator = estimator
self.random_state = random_state
self.verbose = verbose
self.test_size = test_size
self.permute_per_tree = permute_per_tree
self.stratify = stratify

# XXX: possibly removing these parameters
self.sample_dataset_per_tree = sample_dataset_per_tree
self.permute_forest_fraction = permute_forest_fraction
self.stratify = stratify

self.n_samples_test_ = None
self._n_samples_ = None
Expand Down Expand Up @@ -166,15 +166,33 @@ def _get_estimators_indices(self, stratifier=None, sample_separate=False):
# Get drawn indices along both sample and feature axes
rng = np.random.default_rng(self.estimator_.random_state)

if self.sample_dataset_per_tree:
if self.permute_forest_fraction is None:
permute_forest_fraction = 0.0
else:
permute_forest_fraction = self.permute_forest_fraction

# TODO: consolidate how we "sample/permute" per subset of the forest
if self.sample_dataset_per_tree or permute_forest_fraction > 0.0:
# sample random seeds
if self._seeds is None:
self._seeds = []
self._n_permutations = 0

for tree in self.estimator_.estimators_:
if tree.random_state is None:
self._seeds.append(rng.integers(low=0, high=np.iinfo(np.int32).max))
else:
self._seeds.append(tree.random_state)
num_trees_per_seed = max(
int(permute_forest_fraction * len(self.estimator_.estimators_)), 1
)
for tree_idx, tree in enumerate(self.estimator_.estimators_):
if tree_idx == 0 or tree_idx % num_trees_per_seed == 0:
if tree.random_state is None:
seed = rng.integers(low=0, high=np.iinfo(np.int32).max)
else:
seed = tree.random_state

self._n_permutations += 1
self._seeds.append(seed)

# now that we have the random seeds, we can sample the train/test indices
# deterministically
seeds = self._seeds

# if sample_separate:
Expand Down Expand Up @@ -632,7 +650,7 @@ def __init__(
random_state=None,
verbose=0,
test_size=0.2,
permute_per_tree=False,
# permute_per_tree=False,
sample_dataset_per_tree=False,
permute_forest_fraction=None,
):
Expand All @@ -641,7 +659,7 @@ def __init__(
random_state=random_state,
verbose=verbose,
test_size=test_size,
permute_per_tree=permute_per_tree,
# permute_per_tree=permute_per_tree,
sample_dataset_per_tree=sample_dataset_per_tree,
permute_forest_fraction=permute_forest_fraction,
)
Expand Down Expand Up @@ -721,7 +739,7 @@ def _statistic(

# both sampling dataset per tree or permuting per tree requires us to bypass the
# sklearn API to fit each tree individually
if self.sample_dataset_per_tree or self.permute_per_tree:
if self.sample_dataset_per_tree or self.permute_forest_fraction:
# Parallel(n_jobs=estimator.n_jobs, verbose=self.verbose, prefer="threads")(
# delayed(_parallel_build_trees_and_compute_posteriors)(
# estimator,
Expand All @@ -741,7 +759,7 @@ def _statistic(
# )
# )

if self.permute_per_tree and covariate_index is not None:
if self.permute_forest_fraction and covariate_index is not None:
random_states = [tree.random_state for tree in estimator.estimators_]
else:
random_states = [estimator.random_state] * len(estimator.estimators_)
Expand Down Expand Up @@ -799,9 +817,7 @@ def _statistic(
delayed(_parallel_predict_proba)(estimator.estimators_[idx].predict, X, indices_test)
for idx, (_, indices_test) in enumerate(self.train_test_samples_)
)
for itree, (proba, est_indices) in enumerate(
zip(all_proba, self.train_test_samples_)
):
for itree, (proba, est_indices) in enumerate(zip(all_proba, self.train_test_samples_)):
_, indices_test = est_indices
posterior_arr[itree, indices_test, ...] = proba.reshape(-1, estimator.n_outputs_)

Expand Down Expand Up @@ -923,7 +939,7 @@ def __init__(
random_state=None,
verbose=0,
test_size=0.2,
permute_per_tree=False,
# permute_per_tree=False,
sample_dataset_per_tree=False,
stratify=True,
permute_forest_fraction=None,
Expand All @@ -933,7 +949,7 @@ def __init__(
random_state=random_state,
verbose=verbose,
test_size=test_size,
permute_per_tree=permute_per_tree,
# permute_per_tree=permute_per_tree,
sample_dataset_per_tree=sample_dataset_per_tree,
stratify=stratify,
permute_forest_fraction=permute_forest_fraction,
Expand Down Expand Up @@ -982,7 +998,7 @@ def _statistic(

# both sampling dataset per tree or permuting per tree requires us to bypass the
# sklearn API to fit each tree individually
if self.sample_dataset_per_tree or self.permute_per_tree:
if self.sample_dataset_per_tree or self.permute_forest_fraction:
# Parallel(n_jobs=estimator.n_jobs, verbose=self.verbose, prefer="threads")(
# delayed(_parallel_build_trees_and_compute_posteriors)(
# estimator,
Expand All @@ -1001,7 +1017,7 @@ def _statistic(
# self._get_estimators_indices(sample_separate=True)
# )
# )
if self.permute_per_tree and covariate_index is not None:
if self.permute_forest_fraction and covariate_index is not None:
random_states = [tree.random_state for tree in estimator.estimators_]
else:
random_states = [estimator.random_state] * len(estimator.estimators_)
Expand Down Expand Up @@ -1074,9 +1090,7 @@ def _statistic(
)
for idx, (_, indices_test) in enumerate(self.train_test_samples_)
)
for itree, (proba, est_indices) in enumerate(
zip(all_proba, self.train_test_samples_)
):
for itree, (proba, est_indices) in enumerate(zip(all_proba, self.train_test_samples_)):
_, indices_test = est_indices

if predict_posteriors:
Expand Down
Loading

0 comments on commit e894708

Please sign in to comment.