Skip to content

Commit

Permalink
FEA Add _build_pruned_tree to tree.pxd file to allow cimports and a…
Browse files Browse the repository at this point in the history
… `_build_pruned_tree_py` to allow anyone to prune trees (scikit-learn#29590)

Signed-off-by: Adam Li <adam2392@gmail.com>
Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com>
  • Loading branch information
adam2392 and thomasjpfan authored Aug 13, 2024
1 parent 8392e92 commit 151cb2d
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 1 deletion.
17 changes: 17 additions & 0 deletions sklearn/tree/_tree.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -114,3 +114,20 @@ cdef class TreeBuilder:
const float64_t[:, ::1] y,
const float64_t[:] sample_weight,
)


# =============================================================================
# Tree pruning
# =============================================================================

# The private function allows any external caller to prune the tree and return
# a new tree with the pruned nodes. The pruned tree is a new tree object.
#
# .. warning:: this function is not backwards compatible and may change without
# notice.
cdef void _build_pruned_tree(
Tree tree, # OUT
Tree orig_tree,
const uint8_t[:] leaves_in_subtree,
intp_t capacity
)
41 changes: 40 additions & 1 deletion sklearn/tree/_tree.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1872,7 +1872,7 @@ cdef struct BuildPrunedRecord:
intp_t parent
bint is_left

cdef _build_pruned_tree(
cdef void _build_pruned_tree(
Tree tree, # OUT
Tree orig_tree,
const uint8_t[:] leaves_in_subtree,
Expand Down Expand Up @@ -1931,6 +1931,15 @@ cdef _build_pruned_tree(
is_leaf = leaves_in_subtree[orig_node_id]
node = &orig_tree.nodes[orig_node_id]

# protect against an infinite loop as a runtime error, when leaves_in_subtree
# are improperly set where a node is not marked as a leaf, but is a node
# in the original tree. Thus, it violates the assumption that the node
# is a leaf in the pruned tree, or has a descendant that will be pruned.
if (not is_leaf and node.left_child == _TREE_LEAF
and node.right_child == _TREE_LEAF):
rc = -2
break

new_node_id = tree._add_node(
parent, is_left, is_leaf, node.feature, node.threshold,
node.impurity, node.n_node_samples,
Expand Down Expand Up @@ -1960,3 +1969,33 @@ cdef _build_pruned_tree(
tree.max_depth = max_depth_seen
if rc == -1:
raise MemoryError("pruning tree")
elif rc == -2:
raise ValueError(
"Node has reached a leaf in the original tree, but is not "
"marked as a leaf in the leaves_in_subtree mask."
)


def _build_pruned_tree_py(Tree tree, Tree orig_tree, const uint8_t[:] leaves_in_subtree):
"""Build a pruned tree.
Build a pruned tree from the original tree by transforming the nodes in
``leaves_in_subtree`` into leaves.
Parameters
----------
tree : Tree
Location to place the pruned tree
orig_tree : Tree
Original tree
leaves_in_subtree : uint8_t ndarray, shape=(node_count, )
Boolean mask for leaves to include in subtree. The array must have
the same size as the number of nodes in the original tree.
"""
if leaves_in_subtree.shape[0] != orig_tree.node_count:
raise ValueError(
f"The length of leaves_in_subtree {len(leaves_in_subtree)} must be "
f"equal to the number of nodes in the original tree {orig_tree.node_count}."
)

_build_pruned_tree(tree, orig_tree, leaves_in_subtree, orig_tree.node_count)
1 change: 1 addition & 0 deletions sklearn/tree/_utils.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ from ._tree cimport Node
from ..neighbors._quad_tree cimport Cell
from ..utils._typedefs cimport float32_t, float64_t, intp_t, uint8_t, int32_t, uint32_t


cdef enum:
# Max value for our rand_r replacement (near the bottom).
# We don't use RAND_MAX because it's different across platforms and
Expand Down
50 changes: 50 additions & 0 deletions sklearn/tree/tests/test_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
NODE_DTYPE,
TREE_LEAF,
TREE_UNDEFINED,
_build_pruned_tree_py,
_check_n_classes,
_check_node_ndarray,
_check_value_ndarray,
Expand Down Expand Up @@ -2783,3 +2784,52 @@ def test_classification_tree_missing_values_toy():
(tree.tree_.children_left == -1) & (tree.tree_.n_node_samples == 1)
)
assert_allclose(tree.tree_.impurity[leaves_idx], 0.0)


def test_build_pruned_tree_py():
"""Test pruning a tree with the Python caller of the Cythonized prune tree."""
tree = DecisionTreeClassifier(random_state=0, max_depth=1)
tree.fit(iris.data, iris.target)

n_classes = np.atleast_1d(tree.n_classes_)
pruned_tree = CythonTree(tree.n_features_in_, n_classes, tree.n_outputs_)

# only keep the root note
leave_in_subtree = np.zeros(tree.tree_.node_count, dtype=np.uint8)
leave_in_subtree[0] = 1
_build_pruned_tree_py(pruned_tree, tree.tree_, leave_in_subtree)

assert tree.tree_.node_count == 3
assert pruned_tree.node_count == 1
with pytest.raises(AssertionError):
assert_array_equal(tree.tree_.value, pruned_tree.value)
assert_array_equal(tree.tree_.value[0], pruned_tree.value[0])

# now keep all the leaves
pruned_tree = CythonTree(tree.n_features_in_, n_classes, tree.n_outputs_)
leave_in_subtree = np.zeros(tree.tree_.node_count, dtype=np.uint8)
leave_in_subtree[1:] = 1

# Prune the tree
_build_pruned_tree_py(pruned_tree, tree.tree_, leave_in_subtree)
assert tree.tree_.node_count == 3
assert pruned_tree.node_count == 3, pruned_tree.node_count
assert_array_equal(tree.tree_.value, pruned_tree.value)


def test_build_pruned_tree_infinite_loop():
"""Test pruning a tree does not result in an infinite loop."""

# Create a tree with root and two children
tree = DecisionTreeClassifier(random_state=0, max_depth=1)
tree.fit(iris.data, iris.target)
n_classes = np.atleast_1d(tree.n_classes_)
pruned_tree = CythonTree(tree.n_features_in_, n_classes, tree.n_outputs_)

# only keeping one child as a leaf results in an improper tree
leave_in_subtree = np.zeros(tree.tree_.node_count, dtype=np.uint8)
leave_in_subtree[1] = 1
with pytest.raises(
ValueError, match="Node has reached a leaf in the original tree"
):
_build_pruned_tree_py(pruned_tree, tree.tree_, leave_in_subtree)

0 comments on commit 151cb2d

Please sign in to comment.