Skip to content

Commit

Permalink
commented changes to tree
Browse files Browse the repository at this point in the history
  • Loading branch information
SamuelCarliles3 committed Dec 6, 2024
1 parent f655401 commit 877a822
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 57 deletions.
7 changes: 7 additions & 0 deletions sklearn/tree/_tree.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,9 @@ cdef extern from "<stack>" namespace "std" nogil:
void push(T&) except + # Raise c++ exception for bad_alloc -> MemoryError
T& top()

# A large portion of the tree build function was duplicated almost verbatim in the
# neurodata fork of sklearn. We refactor that out into its own function, and it's
# most convenient to encapsulate all the tree build state into its own env struct.
cdef enum TreeBuildStatus:
OK = 0
EXCEPTION_OR_MEMORY_ERROR = -1
Expand Down Expand Up @@ -113,6 +116,9 @@ cdef struct BuildEnv:

ParentInfo parent_record


# We add tree build events to notify interested parties of tree build state.
# Only current relevant events are implemented.
cdef enum TreeBuildEvent:
ADD_NODE = 1
UPDATE_NODE = 2
Expand Down Expand Up @@ -263,6 +269,7 @@ cdef class TreeBuilder:

cdef unsigned char store_leaf_values # Whether to store leaf values

# event broker for distributing tree build events
cdef EventBroker event_broker


Expand Down
65 changes: 8 additions & 57 deletions sklearn/tree/_tree.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -269,13 +269,11 @@ cdef class DepthFirstTreeBuilder(TreeBuilder):

cdef void _build_body(self, EventBroker broker, Tree tree, Splitter splitter, BuildEnv* e, bint update) noexcept nogil:
cdef TreeBuildEvent evt

# payloads for different tree build events
cdef TreeBuildSetActiveParentEventData parent_event_data
cdef TreeBuildAddNodeEventData add_update_node_data

#with gil:
# print("")
# print("_build_body")

while not e.target_stack.empty():
e.stack_record = e.target_stack.top()
e.target_stack.pop()
Expand All @@ -295,15 +293,10 @@ cdef class DepthFirstTreeBuilder(TreeBuilder):
parent_event_data.parent_node_id = e.stack_record.parent
parent_event_data.child_is_left = e.stack_record.is_left

#with gil:
# print(f"start {e.start}")
# print(f"end {e.end}")
# print(f"parent {<int>e.parent}")
# print(f"is_left {e.is_left}")
# print(f"n_node_samples {e.n_node_samples}")
# print(f"parent_node_id {parent_event_data.parent_node_id}")
# print(f"child_is_left {parent_event_data.child_is_left}")

# tree build state is kind of weird as implemented because
# the child node id is assigned after child node creation, and all
# situational awareness during creation is referenced to the parent node.
# so we fire an event indicating the current active parent.
if not broker.fire_event(TreeBuildEvent.SET_ACTIVE_PARENT, &parent_event_data):
e.rc = TreeBuildStatus.EVENT_ERROR
break
Expand All @@ -315,29 +308,13 @@ cdef class DepthFirstTreeBuilder(TreeBuilder):
e.n_node_samples < 2 * e.min_samples_leaf or
e.weighted_n_node_samples < 2 * e.min_weight_leaf)

#with gil:
# print("")
# print(f"*** IS_LEAF ***")
# print(f"is_leaf = {e.is_leaf}")
# print(f"depth = {e.depth}")
# print(f"max_depth = {e.max_depth}")
# print(f"n_node_samples = {e.n_node_samples}")
# print(f"min_samples_split = {e.min_samples_split}")
# print(f"min_samples_leaf = {e.min_samples_leaf}")
# print(f"weighted_n_node_samples = {e.weighted_n_node_samples}")
# print(f"min_weight_leaf = {e.min_weight_leaf}")

if e.first:
e.parent_record.impurity = splitter.node_impurity()
e.first = 0

# impurity == 0 with tolerance due to rounding errors
e.is_leaf = e.is_leaf or e.parent_record.impurity <= EPSILON

#with gil:
# print(f"is_leaf 2 = {e.is_leaf}")
# print(f"parent_record.impurity = {e.parent_record.impurity}")

add_update_node_data.parent_node_id = e.parent
add_update_node_data.is_left = e.is_left
add_update_node_data.feature = -1
Expand All @@ -349,9 +326,6 @@ cdef class DepthFirstTreeBuilder(TreeBuilder):
e.split,
)

#with gil:
# print("_build_body checkpoint 1")

# If EPSILON=0 in the below comparison, float precision
# issues stop splitting, producing trees that are
# dissimilar to v0.18
Expand All @@ -363,14 +337,6 @@ cdef class DepthFirstTreeBuilder(TreeBuilder):
add_update_node_data.feature = e.split.feature
add_update_node_data.split_point = e.split.threshold

#with gil:
# print("_build_body checkpoint 2")
# print(f"is_leaf 3 = {e.is_leaf}")
# print(f"split.pos = {e.split.pos}")
# print(f"end = {e.end}")
# print(f"split.improvement = {e.split.improvement}")
# print(f"min_impurity_decrease = {e.min_impurity_decrease}")
# print(f"feature = {e.split.feature}")

if update == 1:
e.node_id = tree._update_node(
Expand All @@ -387,29 +353,17 @@ cdef class DepthFirstTreeBuilder(TreeBuilder):
)
evt = TreeBuildEvent.ADD_NODE

#with gil:
# print("_build_body checkpoint 3")

if e.node_id == INTPTR_MAX:
#with gil:
# print("_build_body checkpoint 3.25")
e.rc = TreeBuildStatus.EXCEPTION_OR_MEMORY_ERROR
break

#with gil:
# print("_build_body checkpoint 3.5")

add_update_node_data.node_id = e.node_id
add_update_node_data.is_leaf = e.is_leaf

#with gil:
# print("_build_body checkpoint 3.6")

# now that all relevant information has been accumulated,
# notify interested parties that a node has been added/updated
broker.fire_event(evt, &add_update_node_data)

#with gil:
# print("_build_body checkpoint 4")

# Store value for all nodes, to facilitate tree/model
# inspection and interpretation
splitter.node_value(tree.value + e.node_id * tree.value_stride)
Expand All @@ -420,9 +374,6 @@ cdef class DepthFirstTreeBuilder(TreeBuilder):
e.parent_record.upper_bound
)

#with gil:
# print("_build_body checkpoint 5")

if not e.is_leaf:
if (
not splitter.with_monotonic_cst or
Expand Down

0 comments on commit 877a822

Please sign in to comment.