Skip to content

Commit

Permalink
Factor out construct trees API
Browse files Browse the repository at this point in the history
Signed-off-by: Adam Li <adam2392@gmail.com>
  • Loading branch information
adam2392 committed Feb 23, 2024
1 parent d48716a commit 33039e2
Showing 1 changed file with 57 additions and 34 deletions.
91 changes: 57 additions & 34 deletions sklearn/ensemble/_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -595,42 +595,18 @@ def fit(self, X, y, sample_weight=None, classes=None):
# would have got if we hadn't used a warm_start.
random_state.randint(MAX_INT, size=len(self.estimators_))

trees = [
self._make_estimator(append=False, random_state=random_state)
for i in range(n_more_estimators)
]

# Parallel loop: we prefer the threading backend as the Cython code
# for fitting the trees is internally releasing the Python GIL
# making threading more efficient than multiprocessing in
# that case. However, for joblib 0.12+ we respect any
# parallel_backend contexts set at a higher level,
# since correctness does not rely on using threads.
trees = Parallel(
n_jobs=self.n_jobs,
verbose=self.verbose,
prefer="threads",
)(
delayed(_parallel_build_trees)(
t,
self.bootstrap,
X,
y,
sample_weight,
i,
len(trees),
verbose=self.verbose,
class_weight=self.class_weight,
n_samples_bootstrap=n_samples_bootstrap,
missing_values_in_feature_mask=missing_values_in_feature_mask,
classes=classes,
)
for i, t in enumerate(trees)
# construct the trees in parallel
self._construct_trees(
X,
y,
sample_weight,
random_state,
n_samples_bootstrap,
missing_values_in_feature_mask,
classes,
n_more_estimators,
)

# Collect newly grown trees
self.estimators_.extend(trees)

if self.oob_score and (
n_more_estimators > 0 or not hasattr(self, "oob_score_")
):
Expand Down Expand Up @@ -664,6 +640,53 @@ def fit(self, X, y, sample_weight=None, classes=None):

return self

def _construct_trees(
    self,
    X,
    y,
    sample_weight,
    random_state,
    n_samples_bootstrap,
    missing_values_in_feature_mask,
    classes,
    n_more_estimators,
):
    """Create, fit and store ``n_more_estimators`` additional trees.

    The new estimators are first instantiated one after another with
    ``self._make_estimator(append=False, ...)``, then each one is handed
    to ``_parallel_build_trees`` through a joblib ``Parallel`` call, and
    the objects returned by that call are appended to
    ``self.estimators_``.

    Parameters
    ----------
    X, y, sample_weight
        Forwarded unchanged, positionally, to ``_parallel_build_trees``.
    random_state
        Passed as ``random_state`` to every ``_make_estimator`` call.
    n_samples_bootstrap, missing_values_in_feature_mask, classes
        Forwarded as keyword arguments to ``_parallel_build_trees``.
    n_more_estimators : int
        Number of new estimators to create.

    Returns
    -------
    None
        ``self.estimators_`` is extended in place.
    """
    # All estimators are instantiated up front, before any fitting starts.
    unfitted = []
    for _ in range(n_more_estimators):
        unfitted.append(
            self._make_estimator(append=False, random_state=random_state)
        )
    n_new = len(unfitted)

    def _fit_jobs():
        # Lazily yield one delayed fitting task per freshly created tree.
        for tree_idx, tree in enumerate(unfitted):
            yield delayed(_parallel_build_trees)(
                tree,
                self.bootstrap,
                X,
                y,
                sample_weight,
                tree_idx,
                n_new,
                verbose=self.verbose,
                class_weight=self.class_weight,
                n_samples_bootstrap=n_samples_bootstrap,
                missing_values_in_feature_mask=missing_values_in_feature_mask,
                classes=classes,
            )

    # Parallel loop: we prefer the threading backend as the Cython code
    # for fitting the trees is internally releasing the Python GIL
    # making threading more efficient than multiprocessing in
    # that case. However, for joblib 0.12+ we respect any
    # parallel_backend contexts set at a higher level,
    # since correctness does not rely on using threads.
    executor = Parallel(
        n_jobs=self.n_jobs,
        verbose=self.verbose,
        prefer="threads",
    )
    fitted = executor(_fit_jobs())

    # Collect newly grown trees
    self.estimators_.extend(fitted)

@abstractmethod
def _set_oob_score_and_attributes(self, X, y, scoring_function=None):
"""Compute and set the OOB score and attributes.
Expand Down

0 comments on commit 33039e2

Please sign in to comment.