diff --git a/sklearn/tree/_tree.pxd b/sklearn/tree/_tree.pxd index 41d53b01ac276..9b11face3e6bf 100644 --- a/sklearn/tree/_tree.pxd +++ b/sklearn/tree/_tree.pxd @@ -69,6 +69,9 @@ cdef extern from "<stack>" namespace "std" nogil: void push(T&) except + # Raise c++ exception for bad_alloc -> MemoryError T& top() +# A large portion of the tree build function was duplicated almost verbatim in the +# neurodata fork of sklearn. We refactor that out into its own function, and it's +# most convenient to encapsulate all the tree build state into its own env struct. cdef enum TreeBuildStatus: OK = 0 EXCEPTION_OR_MEMORY_ERROR = -1 @@ -113,6 +116,9 @@ cdef struct BuildEnv: ParentInfo parent_record + +# We add tree build events to notify interested parties of tree build state. +# Only currently relevant events are implemented. cdef enum TreeBuildEvent: ADD_NODE = 1 UPDATE_NODE = 2 @@ -263,6 +269,7 @@ cdef class TreeBuilder: cdef unsigned char store_leaf_values # Whether to store leaf values + # event broker for distributing tree build events cdef EventBroker event_broker diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx index d9fcc8322ddcb..918bde971d426 100644 --- a/sklearn/tree/_tree.pyx +++ b/sklearn/tree/_tree.pyx @@ -269,13 +269,11 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): cdef void _build_body(self, EventBroker broker, Tree tree, Splitter splitter, BuildEnv* e, bint update) noexcept nogil: cdef TreeBuildEvent evt + + # payloads for different tree build events cdef TreeBuildSetActiveParentEventData parent_event_data cdef TreeBuildAddNodeEventData add_update_node_data - #with gil: - # print("") - # print("_build_body") - while not e.target_stack.empty(): e.stack_record = e.target_stack.top() e.target_stack.pop() @@ -295,15 +293,10 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): parent_event_data.parent_node_id = e.stack_record.parent parent_event_data.child_is_left = e.stack_record.is_left - #with gil: - # print(f"start {e.start}") - # print(f"end {e.end}") - # 
print(f"parent {e.parent}") - # print(f"is_left {e.is_left}") - # print(f"n_node_samples {e.n_node_samples}") - # print(f"parent_node_id {parent_event_data.parent_node_id}") - # print(f"child_is_left {parent_event_data.child_is_left}") - + # tree build state is kind of weird as implemented because + # the child node id is assigned after child node creation, and all + # situational awareness during creation is referenced to the parent node. + # so we fire an event indicating the current active parent. if not broker.fire_event(TreeBuildEvent.SET_ACTIVE_PARENT, &parent_event_data): e.rc = TreeBuildStatus.EVENT_ERROR break @@ -315,18 +308,6 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): e.n_node_samples < 2 * e.min_samples_leaf or e.weighted_n_node_samples < 2 * e.min_weight_leaf) - #with gil: - # print("") - # print(f"*** IS_LEAF ***") - # print(f"is_leaf = {e.is_leaf}") - # print(f"depth = {e.depth}") - # print(f"max_depth = {e.max_depth}") - # print(f"n_node_samples = {e.n_node_samples}") - # print(f"min_samples_split = {e.min_samples_split}") - # print(f"min_samples_leaf = {e.min_samples_leaf}") - # print(f"weighted_n_node_samples = {e.weighted_n_node_samples}") - # print(f"min_weight_leaf = {e.min_weight_leaf}") - if e.first: e.parent_record.impurity = splitter.node_impurity() e.first = 0 @@ -334,10 +315,6 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): # impurity == 0 with tolerance due to rounding errors e.is_leaf = e.is_leaf or e.parent_record.impurity <= EPSILON - #with gil: - # print(f"is_leaf 2 = {e.is_leaf}") - # print(f"parent_record.impurity = {e.parent_record.impurity}") - add_update_node_data.parent_node_id = e.parent add_update_node_data.is_left = e.is_left add_update_node_data.feature = -1 @@ -349,9 +326,6 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): e.split, ) - #with gil: - # print("_build_body checkpoint 1") - # If EPSILON=0 in the below comparison, float precision # issues stop splitting, producing trees that are # dissimilar to v0.18 @@ 
-363,14 +337,6 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): add_update_node_data.feature = e.split.feature add_update_node_data.split_point = e.split.threshold - #with gil: - # print("_build_body checkpoint 2") - # print(f"is_leaf 3 = {e.is_leaf}") - # print(f"split.pos = {e.split.pos}") - # print(f"end = {e.end}") - # print(f"split.improvement = {e.split.improvement}") - # print(f"min_impurity_decrease = {e.min_impurity_decrease}") - # print(f"feature = {e.split.feature}") if update == 1: e.node_id = tree._update_node( @@ -387,29 +353,17 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): ) evt = TreeBuildEvent.ADD_NODE - #with gil: - # print("_build_body checkpoint 3") - if e.node_id == INTPTR_MAX: - #with gil: - # print("_build_body checkpoint 3.25") e.rc = TreeBuildStatus.EXCEPTION_OR_MEMORY_ERROR break - #with gil: - # print("_build_body checkpoint 3.5") - add_update_node_data.node_id = e.node_id add_update_node_data.is_leaf = e.is_leaf - #with gil: - # print("_build_body checkpoint 3.6") - + # now that all relevant information has been accumulated, + # notify interested parties that a node has been added/updated broker.fire_event(evt, &add_update_node_data) - #with gil: - # print("_build_body checkpoint 4") - # Store value for all nodes, to facilitate tree/model # inspection and interpretation splitter.node_value(tree.value + e.node_id * tree.value_stride) @@ -420,9 +374,6 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): e.parent_record.upper_bound ) - #with gil: - # print("_build_body checkpoint 5") - if not e.is_leaf: if ( not splitter.with_monotonic_cst or