Skip to content

Commit

Permalink
Fix update tree node
Browse files Browse the repository at this point in the history
Signed-off-by: Adam Li <adam2392@gmail.com>
  • Loading branch information
adam2392 committed Apr 2, 2024
1 parent a52ec74 commit 750573c
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 4 deletions.
12 changes: 12 additions & 0 deletions sklearn/tree/_tree.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,18 @@ cdef class BaseTree:
cdef int _resize(self, intp_t capacity) except -1 nogil
cdef int _resize_c(self, intp_t capacity=*) except -1 nogil

cdef int _update_node(
self,
intp_t parent,
bint is_left,
bint is_leaf,
SplitRecord* split_node,
float64_t impurity,
intp_t n_node_samples,
float64_t weighted_n_node_samples,
unsigned char missing_go_to_left
) except -1 nogil

# Python API methods: These are methods exposed to Python
cpdef cnp.ndarray apply(self, object X)
cdef cnp.ndarray _apply_dense(self, object X)
Expand Down
52 changes: 48 additions & 4 deletions sklearn/tree/_tree.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -422,10 +422,10 @@ cdef class DepthFirstTreeBuilder(TreeBuilder):
(split.improvement + EPSILON <
min_impurity_decrease))

node_id = tree._add_node(parent, is_left, is_leaf, split_ptr,
parent_record.impurity,
n_node_samples, weighted_n_node_samples,
split.missing_go_to_left)
node_id = tree._update_node(parent, is_left, is_leaf, split_ptr,
parent_record.impurity,
n_node_samples, weighted_n_node_samples,
split.missing_go_to_left)

if node_id == INTPTR_MAX:
rc = -1
Expand Down Expand Up @@ -1175,6 +1175,50 @@ cdef class BaseTree:

return node_id

cdef inline int _update_node(
self,
intp_t parent,
bint is_left,
bint is_leaf,
SplitRecord* split_node,
float64_t impurity,
intp_t n_node_samples,
float64_t weighted_n_node_samples,
unsigned char missing_go_to_left
) except -1 nogil:
"""Update a node on the tree.
The updated node remains on the same position.
Returns (intp_t)(-1) on error.
"""
cdef intp_t node_id
if is_left:
node_id = self.nodes[parent].left_child
else:
node_id = self.nodes[parent].right_child

if node_id >= self.capacity:
if self._resize_c() != 0:
return INTPTR_MAX

cdef Node* node = &self.nodes[node_id]
node.impurity = impurity
node.n_node_samples = n_node_samples
node.weighted_n_node_samples = weighted_n_node_samples

if is_leaf:
if self._set_leaf_node(split_node, node, node_id) != 1:
with gil:
raise RuntimeError
else:
if self._set_split_node(split_node, node, node_id) != 1:
with gil:
raise RuntimeError
node.missing_go_to_left = missing_go_to_left

return node_id

cpdef cnp.ndarray apply(self, object X):
"""Finds the terminal region (=leaf node) for each sample in X."""
if issparse(X):
Expand Down

0 comments on commit 750573c

Please sign in to comment.