diff --git a/sklearn/tree/_classes.py b/sklearn/tree/_classes.py index 84a41aff1174c..6511c8192889e 100644 --- a/sklearn/tree/_classes.py +++ b/sklearn/tree/_classes.py @@ -413,7 +413,7 @@ def _fit( min_weight_leaf = self.min_weight_fraction_leaf * np.sum(sample_weight) # build the actual tree now with the parameters - self._build_tree( + self = self._build_tree( X=X, y=y, sample_weight=sample_weight, @@ -573,6 +573,7 @@ def _build_tree( self.classes_ = self.classes_[0] self._prune_tree() + return self def _validate_X_predict(self, X, check_input): """Validate the training data on predict (probabilities)."""