diff --git a/sklearn/ensemble/_forest.py b/sklearn/ensemble/_forest.py
index 3827359b9162e..b5ee64b6e708c 100644
--- a/sklearn/ensemble/_forest.py
+++ b/sklearn/ensemble/_forest.py
@@ -595,42 +595,18 @@ def fit(self, X, y, sample_weight=None, classes=None):
                 # would have got if we hadn't used a warm_start.
                 random_state.randint(MAX_INT, size=len(self.estimators_))
 
-            trees = [
-                self._make_estimator(append=False, random_state=random_state)
-                for i in range(n_more_estimators)
-            ]
-
-            # Parallel loop: we prefer the threading backend as the Cython code
-            # for fitting the trees is internally releasing the Python GIL
-            # making threading more efficient than multiprocessing in
-            # that case. However, for joblib 0.12+ we respect any
-            # parallel_backend contexts set at a higher level,
-            # since correctness does not rely on using threads.
-            trees = Parallel(
-                n_jobs=self.n_jobs,
-                verbose=self.verbose,
-                prefer="threads",
-            )(
-                delayed(_parallel_build_trees)(
-                    t,
-                    self.bootstrap,
-                    X,
-                    y,
-                    sample_weight,
-                    i,
-                    len(trees),
-                    verbose=self.verbose,
-                    class_weight=self.class_weight,
-                    n_samples_bootstrap=n_samples_bootstrap,
-                    missing_values_in_feature_mask=missing_values_in_feature_mask,
-                    classes=classes,
-                )
-                for i, t in enumerate(trees)
+            # construct the trees in parallel
+            self._construct_trees(
+                X,
+                y,
+                sample_weight,
+                random_state,
+                n_samples_bootstrap,
+                missing_values_in_feature_mask,
+                classes,
+                n_more_estimators,
             )
 
-            # Collect newly grown trees
-            self.estimators_.extend(trees)
-
         if self.oob_score and (
             n_more_estimators > 0 or not hasattr(self, "oob_score_")
         ):
@@ -664,6 +640,53 @@ def fit(self, X, y, sample_weight=None, classes=None):
 
         return self
 
+    def _construct_trees(
+        self,
+        X,
+        y,
+        sample_weight,
+        random_state,
+        n_samples_bootstrap,
+        missing_values_in_feature_mask,
+        classes,
+        n_more_estimators,
+    ):
+        trees = [
+            self._make_estimator(append=False, random_state=random_state)
+            for i in range(n_more_estimators)
+        ]
+
+        # Parallel loop: we prefer the threading backend as the Cython code
+        # for fitting the trees is internally releasing the Python GIL
+        # making threading more efficient than multiprocessing in
+        # that case. However, for joblib 0.12+ we respect any
+        # parallel_backend contexts set at a higher level,
+        # since correctness does not rely on using threads.
+        trees = Parallel(
+            n_jobs=self.n_jobs,
+            verbose=self.verbose,
+            prefer="threads",
+        )(
+            delayed(_parallel_build_trees)(
+                t,
+                self.bootstrap,
+                X,
+                y,
+                sample_weight,
+                i,
+                len(trees),
+                verbose=self.verbose,
+                class_weight=self.class_weight,
+                n_samples_bootstrap=n_samples_bootstrap,
+                missing_values_in_feature_mask=missing_values_in_feature_mask,
+                classes=classes,
+            )
+            for i, t in enumerate(trees)
+        )
+
+        # Collect newly grown trees
+        self.estimators_.extend(trees)
+
     @abstractmethod
     def _set_oob_score_and_attributes(self, X, y, scoring_function=None):
         """Compute and set the OOB score and attributes.
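
With this refactor, the tree-growing step that fit() previously ran inline is isolated in the new _construct_trees helper, which builds n_more_estimators trees in parallel and extends self.estimators_. The sketch below is illustrative only and not part of the patch: the LoggingRandomForestClassifier name is invented for the example, _construct_trees is a private method, and the code assumes a scikit-learn build with this patch applied.

    # Illustrative sketch only, assuming a scikit-learn build that includes
    # the _construct_trees hook from this patch; the subclass is hypothetical.
    from sklearn.datasets import make_classification
    from sklearn.ensemble import RandomForestClassifier


    class LoggingRandomForestClassifier(RandomForestClassifier):
        """Random forest that reports how many trees each fit() call grows."""

        def _construct_trees(
            self,
            X,
            y,
            sample_weight,
            random_state,
            n_samples_bootstrap,
            missing_values_in_feature_mask,
            classes,
            n_more_estimators,
        ):
            # Delegate the actual parallel tree building to the base class,
            # which extends self.estimators_ with the newly grown trees.
            super()._construct_trees(
                X,
                y,
                sample_weight,
                random_state,
                n_samples_bootstrap,
                missing_values_in_feature_mask,
                classes,
                n_more_estimators,
            )
            print(
                f"grew {n_more_estimators} new trees "
                f"({len(self.estimators_)} total)"
            )


    X, y = make_classification(n_samples=200, random_state=0)
    forest = LoggingRandomForestClassifier(
        n_estimators=10, warm_start=True, random_state=0
    )
    forest.fit(X, y)          # grows 10 trees
    forest.n_estimators = 15
    forest.fit(X, y)          # warm start: grows only the 5 additional trees

Because fit() now routes all tree construction through this single method, a subclass only has to override one place to observe or customize how new trees are grown, for both cold starts and warm starts.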