From 151cb2dc64f60d011b178b91d0defb81d90f72b3 Mon Sep 17 00:00:00 2001 From: Adam Li Date: Tue, 13 Aug 2024 15:13:13 -0400 Subject: [PATCH] FEA Add `_build_pruned_tree` to tree.pxd file to allow cimports and a `_build_pruned_tree_py` to allow anyone to prune trees (#29590) Signed-off-by: Adam Li Co-authored-by: Thomas J. Fan --- sklearn/tree/_tree.pxd | 17 +++++++++++ sklearn/tree/_tree.pyx | 41 ++++++++++++++++++++++++++- sklearn/tree/_utils.pxd | 1 + sklearn/tree/tests/test_tree.py | 50 +++++++++++++++++++++++++++++++++ 4 files changed, 108 insertions(+), 1 deletion(-) diff --git a/sklearn/tree/_tree.pxd b/sklearn/tree/_tree.pxd index 831ca38a11148..2cadca4564a87 100644 --- a/sklearn/tree/_tree.pxd +++ b/sklearn/tree/_tree.pxd @@ -114,3 +114,20 @@ cdef class TreeBuilder: const float64_t[:, ::1] y, const float64_t[:] sample_weight, ) + + +# ============================================================================= +# Tree pruning +# ============================================================================= + +# The private function allows any external caller to prune the tree and return +# a new tree with the pruned nodes. The pruned tree is a new tree object. +# +# .. warning:: this function is not backwards compatible and may change without +# notice. +cdef void _build_pruned_tree( + Tree tree, # OUT + Tree orig_tree, + const uint8_t[:] leaves_in_subtree, + intp_t capacity +) diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx index 43b7770131497..7e6946a718a81 100644 --- a/sklearn/tree/_tree.pyx +++ b/sklearn/tree/_tree.pyx @@ -1872,7 +1872,7 @@ cdef struct BuildPrunedRecord: intp_t parent bint is_left -cdef _build_pruned_tree( +cdef void _build_pruned_tree( Tree tree, # OUT Tree orig_tree, const uint8_t[:] leaves_in_subtree, @@ -1931,6 +1931,15 @@ cdef _build_pruned_tree( is_leaf = leaves_in_subtree[orig_node_id] node = &orig_tree.nodes[orig_node_id] + # protect against an infinite loop as a runtime error, when leaves_in_subtree + # are improperly set where a node is not marked as a leaf, but is a node + # in the original tree. Thus, it violates the assumption that the node + # is a leaf in the pruned tree, or has a descendant that will be pruned. + if (not is_leaf and node.left_child == _TREE_LEAF + and node.right_child == _TREE_LEAF): + rc = -2 + break + new_node_id = tree._add_node( parent, is_left, is_leaf, node.feature, node.threshold, node.impurity, node.n_node_samples, @@ -1960,3 +1969,33 @@ cdef _build_pruned_tree( tree.max_depth = max_depth_seen if rc == -1: raise MemoryError("pruning tree") + elif rc == -2: + raise ValueError( + "Node has reached a leaf in the original tree, but is not " + "marked as a leaf in the leaves_in_subtree mask." + ) + + +def _build_pruned_tree_py(Tree tree, Tree orig_tree, const uint8_t[:] leaves_in_subtree): + """Build a pruned tree. + + Build a pruned tree from the original tree by transforming the nodes in + ``leaves_in_subtree`` into leaves. + + Parameters + ---------- + tree : Tree + Location to place the pruned tree + orig_tree : Tree + Original tree + leaves_in_subtree : uint8_t ndarray, shape=(node_count, ) + Boolean mask for leaves to include in subtree. The array must have + the same size as the number of nodes in the original tree. + """ + if leaves_in_subtree.shape[0] != orig_tree.node_count: + raise ValueError( + f"The length of leaves_in_subtree {len(leaves_in_subtree)} must be " + f"equal to the number of nodes in the original tree {orig_tree.node_count}." + ) + + _build_pruned_tree(tree, orig_tree, leaves_in_subtree, orig_tree.node_count) diff --git a/sklearn/tree/_utils.pxd b/sklearn/tree/_utils.pxd index de16cc65b32a9..bc1d7668187d7 100644 --- a/sklearn/tree/_utils.pxd +++ b/sklearn/tree/_utils.pxd @@ -8,6 +8,7 @@ from ._tree cimport Node from ..neighbors._quad_tree cimport Cell from ..utils._typedefs cimport float32_t, float64_t, intp_t, uint8_t, int32_t, uint32_t + cdef enum: # Max value for our rand_r replacement (near the bottom). # We don't use RAND_MAX because it's different across platforms and diff --git a/sklearn/tree/tests/test_tree.py b/sklearn/tree/tests/test_tree.py index 60d864a73a790..5ef783de305d2 100644 --- a/sklearn/tree/tests/test_tree.py +++ b/sklearn/tree/tests/test_tree.py @@ -39,6 +39,7 @@ NODE_DTYPE, TREE_LEAF, TREE_UNDEFINED, + _build_pruned_tree_py, _check_n_classes, _check_node_ndarray, _check_value_ndarray, @@ -2783,3 +2784,52 @@ def test_classification_tree_missing_values_toy(): (tree.tree_.children_left == -1) & (tree.tree_.n_node_samples == 1) ) assert_allclose(tree.tree_.impurity[leaves_idx], 0.0) + + +def test_build_pruned_tree_py(): + """Test pruning a tree with the Python caller of the Cythonized prune tree.""" + tree = DecisionTreeClassifier(random_state=0, max_depth=1) + tree.fit(iris.data, iris.target) + + n_classes = np.atleast_1d(tree.n_classes_) + pruned_tree = CythonTree(tree.n_features_in_, n_classes, tree.n_outputs_) + + # only keep the root note + leave_in_subtree = np.zeros(tree.tree_.node_count, dtype=np.uint8) + leave_in_subtree[0] = 1 + _build_pruned_tree_py(pruned_tree, tree.tree_, leave_in_subtree) + + assert tree.tree_.node_count == 3 + assert pruned_tree.node_count == 1 + with pytest.raises(AssertionError): + assert_array_equal(tree.tree_.value, pruned_tree.value) + assert_array_equal(tree.tree_.value[0], pruned_tree.value[0]) + + # now keep all the leaves + pruned_tree = CythonTree(tree.n_features_in_, n_classes, tree.n_outputs_) + leave_in_subtree = np.zeros(tree.tree_.node_count, dtype=np.uint8) + leave_in_subtree[1:] = 1 + + # Prune the tree + _build_pruned_tree_py(pruned_tree, tree.tree_, leave_in_subtree) + assert tree.tree_.node_count == 3 + assert pruned_tree.node_count == 3, pruned_tree.node_count + assert_array_equal(tree.tree_.value, pruned_tree.value) + + +def test_build_pruned_tree_infinite_loop(): + """Test pruning a tree does not result in an infinite loop.""" + + # Create a tree with root and two children + tree = DecisionTreeClassifier(random_state=0, max_depth=1) + tree.fit(iris.data, iris.target) + n_classes = np.atleast_1d(tree.n_classes_) + pruned_tree = CythonTree(tree.n_features_in_, n_classes, tree.n_outputs_) + + # only keeping one child as a leaf results in an improper tree + leave_in_subtree = np.zeros(tree.tree_.node_count, dtype=np.uint8) + leave_in_subtree[1] = 1 + with pytest.raises( + ValueError, match="Node has reached a leaf in the original tree" + ): + _build_pruned_tree_py(pruned_tree, tree.tree_, leave_in_subtree)