Skip to content

Commit

Permalink
Try again for partial fit
Browse files Browse the repository at this point in the history
Signed-off-by: Adam Li <adam2392@gmail.com>
  • Loading branch information
adam2392 committed Apr 3, 2024
1 parent e020253 commit 775f0b7
Showing 1 changed file with 70 additions and 66 deletions.
136 changes: 70 additions & 66 deletions sklearn/tree/_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,7 +466,6 @@ def _build_tree(
random_state : int
Random seed.
"""

n_samples = X.shape[0]

# Build tree
Expand Down Expand Up @@ -576,6 +575,75 @@ def _build_tree(
self._prune_tree()
return self

def _update_tree(self, X, y, sample_weight):
# Update tree
max_leaf_nodes = -1 if self.max_leaf_nodes is None else self.max_leaf_nodes
min_samples_split = self.min_samples_split_
min_samples_leaf = self.min_samples_leaf_
min_weight_leaf = self.min_weight_leaf_
# set decision-tree model parameters
max_depth = np.iinfo(np.int32).max if self.max_depth is None else self.max_depth

monotonic_cst = self.monotonic_cst_

# Build tree
# Note: this reconstructs the builder with the same state it had during the
# initial fit. This is necessary because the builder is not saved as part
# of the class, and thus the state may be lost if pickled/unpickled.
n_samples = X.shape[0]
criterion = self.criterion
if not isinstance(criterion, BaseCriterion):
if is_classifier(self):
criterion = CRITERIA_CLF[self.criterion](
self.n_outputs_, self._n_classes_
)
else:
criterion = CRITERIA_REG[self.criterion](self.n_outputs_, n_samples)
else:
# Make a deepcopy in case the criterion has mutable attributes that
# might be shared and modified concurrently during parallel fitting
criterion = copy.deepcopy(criterion)

random_state = check_random_state(self.random_state)

SPLITTERS = SPARSE_SPLITTERS if issparse(X) else DENSE_SPLITTERS
splitter = SPLITTERS[self.splitter](
criterion,
self.max_features_,
min_samples_leaf,
min_weight_leaf,
random_state,
monotonic_cst,
)

# Use BestFirst if max_leaf_nodes given; use DepthFirst otherwise
if max_leaf_nodes < 0:
builder = DepthFirstTreeBuilder(
splitter,
min_samples_split,
min_samples_leaf,
min_weight_leaf,
max_depth,
self.min_impurity_decrease,
self.store_leaf_values,
)
else:
builder = BestFirstTreeBuilder(
splitter,
min_samples_split,
min_samples_leaf,
min_weight_leaf,
max_depth,
max_leaf_nodes,
self.min_impurity_decrease,
self.store_leaf_values,
)
builder.initialize_node_queue(self.tree_, X, y, sample_weight)
builder.build(self.tree_, X, y, sample_weight)

self._prune_tree()
return self

def _validate_X_predict(self, X, check_input):
"""Validate the training data on predict (probabilities)."""
if check_input:
Expand Down Expand Up @@ -1375,71 +1443,7 @@ def partial_fit(self, X, y, sample_weight=None, check_input=True, classes=None):
if getattr(y, "dtype", None) != DOUBLE or not y.flags.contiguous:
y = np.ascontiguousarray(y, dtype=DOUBLE)

# Update tree
max_leaf_nodes = -1 if self.max_leaf_nodes is None else self.max_leaf_nodes
min_samples_split = self.min_samples_split_
min_samples_leaf = self.min_samples_leaf_
min_weight_leaf = self.min_weight_leaf_
# set decision-tree model parameters
max_depth = np.iinfo(np.int32).max if self.max_depth is None else self.max_depth

monotonic_cst = self.monotonic_cst_

# Build tree
# Note: this reconstructs the builder with the same state it had during the
# initial fit. This is necessary because the builder is not saved as part
# of the class, and thus the state may be lost if pickled/unpickled.
n_samples = X.shape[0]
criterion = self.criterion
if not isinstance(criterion, BaseCriterion):
if is_classifier(self):
criterion = CRITERIA_CLF[self.criterion](
self.n_outputs_, self._n_classes_
)
else:
criterion = CRITERIA_REG[self.criterion](self.n_outputs_, n_samples)
else:
# Make a deepcopy in case the criterion has mutable attributes that
# might be shared and modified concurrently during parallel fitting
criterion = copy.deepcopy(criterion)

random_state = check_random_state(self.random_state)
SPLITTERS = SPARSE_SPLITTERS if issparse(X) else DENSE_SPLITTERS
splitter = SPLITTERS[self.splitter](
criterion,
self.max_features_,
min_samples_leaf,
min_weight_leaf,
random_state,
monotonic_cst,
)

# Use BestFirst if max_leaf_nodes given; use DepthFirst otherwise
if max_leaf_nodes < 0:
builder = DepthFirstTreeBuilder(
splitter,
min_samples_split,
min_samples_leaf,
min_weight_leaf,
max_depth,
self.min_impurity_decrease,
self.store_leaf_values,
)
else:
builder = BestFirstTreeBuilder(
splitter,
min_samples_split,
min_samples_leaf,
min_weight_leaf,
max_depth,
max_leaf_nodes,
self.min_impurity_decrease,
self.store_leaf_values,
)
builder.initialize_node_queue(self.tree_, X, y, sample_weight)
builder.build(self.tree_, X, y, sample_weight)

self._prune_tree()
self._update_tree(X, y, sample_weight)

return self

Expand Down

0 comments on commit 775f0b7

Please sign in to comment.