
Commit

commented changes to splitter
SamuelCarliles3 committed Dec 6, 2024
1 parent 5291fb1 commit f655401
Showing 3 changed files with 34 additions and 107 deletions.
4 changes: 0 additions & 4 deletions sklearn/tree/_events.pyx
@@ -50,10 +50,6 @@ cdef class EventBroker:
cdef bint fire_event(self, EventType event_type, EventData event_data) noexcept nogil:
cdef bint result = True

#with gil:
# print(f"firing event {event_type}")
# print(f"listeners.size = {self.listeners.size()}")

if event_type < self.listeners.size():
for l in self.listeners[event_type]:
result = result and l.f(event_type, l.e, event_data)
29 changes: 18 additions & 11 deletions sklearn/tree/_splitter.pxd
@@ -34,17 +34,14 @@ cdef struct NodeSplitEventData:
intp_t feature
float64_t threshold

# NICE IDEAS THAT DON'T APPEAR POSSIBLE
# - accessing elements of a memory view of cython extension types in a nogil block/function
# - storing cython extension types in cpp vectors
#
# despite the fact that we can access scalar extension type properties in such a context,
# as for instance node_split_best does with Criterion and Partition,
# and we can access the elements of a memory view of primitive types in such a context
#
# SO WHERE DOES THAT LEAVE US
# - we can transform these into cpp vectors of structs
# and with some minor casting irritations everything else works ok
# We wish to generalize Splitter so that arbitrary split rejection criteria can be
# passed in dynamically at construction. The natural way to do this would be to
# pass in a list of lambdas, but in cython that is not so straightforward. We want
# the convenience of passing them in as a python list, and while it would be nice
# to receive them as a memoryview, that is quite a nuisance with cython extension
# types, so we store them in a cpp vector instead. For execution speed we use the
# same closure struct pattern as elsewhere, but the closures must be wrapped in
# cython extension types both for convenience and so they can go in a python list.
ctypedef void* SplitConditionEnv
ctypedef bint (*SplitConditionFunction)(
Splitter splitter,
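A minimal, self-contained sketch of the closure-struct pattern described above; every name in it is invented for illustration and is not part of _splitter.pxd.

# distutils: language = c++
# Toy version of the pattern: a C function pointer plus an environment pointer
# live in a plain struct (so they can sit in a cpp vector and be invoked nogil),
# and the struct is wrapped in a cython extension type so it can travel in a
# python list.
from libcpp.vector cimport vector

ctypedef void* ConditionEnv
ctypedef bint (*ConditionFunction)(double value, ConditionEnv env) noexcept nogil

cdef struct ConditionClosure:
    ConditionFunction f
    ConditionEnv e

cdef class Condition:
    cdef ConditionClosure c

cdef double _min_value = 3.0

cdef bint at_least_min(double value, ConditionEnv env) noexcept nogil:
    return value >= (<double*>env)[0]

cdef class AtLeastMinCondition(Condition):
    def __cinit__(self):
        self.c.f = at_least_min
        self.c.e = <ConditionEnv>&_min_value

def check_all(conditions, double value):
    # unwrap the python list of extension types into a cpp vector of structs,
    # then evaluate every closure with the GIL released
    cdef Condition cond
    cdef ConditionClosure cl
    cdef vector[ConditionClosure] closures
    cdef bint ok = True
    for cond in conditions:
        closures.push_back(cond.c)
    with nogil:
        for cl in closures:
            ok = ok and cl.f(value, cl.e)
    return ok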
@@ -79,6 +76,12 @@ cdef struct SplitRecord:
unsigned char missing_go_to_left # Controls if missing values go to the left node.
intp_t n_missing # Number of missing values for the feature being split on


# In the neurodata fork of sklearn, SplitRecord creation was hacked to query the
# splitter for the record size and do an inline malloc, so that Splitter subclasses
# could create extended SplitRecord types. We refactor that into a factory method,
# again implemented as a closure struct.
ctypedef void* SplitRecordFactoryEnv
ctypedef SplitRecord* (*SplitRecordFactory)(SplitRecordFactoryEnv env) except NULL nogil

@@ -168,9 +171,13 @@ cdef class Splitter(BaseSplitter):
cdef SplitCondition min_weight_leaf_condition
cdef SplitCondition monotonic_constraint_condition

# split rejection criteria checked before split selection
cdef vector[SplitConditionClosure] presplit_conditions

# split rejection criteria checked after split selection
cdef vector[SplitConditionClosure] postsplit_conditions

# event broker for handling splitter events
cdef EventBroker event_broker

cdef int init(
108 changes: 16 additions & 92 deletions sklearn/tree/_splitter.pyx
@@ -33,6 +33,8 @@ import numpy as np
cdef float64_t INFINITY = np.inf


# we refactor the inline min sample leaf split rejection criterion
# into our injectable SplitCondition pattern
cdef bint min_sample_leaf_condition(
Splitter splitter,
intp_t split_feature,
@@ -66,6 +68,9 @@ cdef class MinSamplesLeafCondition(SplitCondition):
self.c.f = min_sample_leaf_condition
self.c.e = NULL # min_samples is stored in splitter, which is already passed to f


# we refactor the inline min weight leaf split rejection criterion
# into our injectable SplitCondition pattern
cdef bint min_weight_leaf_condition(
Splitter splitter,
intp_t split_feature,
@@ -91,6 +96,9 @@ cdef class MinWeightLeafCondition(SplitCondition):
self.c.f = min_weight_leaf_condition
self.c.e = NULL # min_weight_leaf is stored in splitter, which is already passed to f


# we refactor the inline monotonic constraint split rejection criterion
# into our injectable SplitCondition pattern
cdef bint monotonic_constraint_condition(
Splitter splitter,
intp_t split_feature,
@@ -131,6 +139,7 @@ cdef inline void _init_split(SplitRecord* self, intp_t start_pos) noexcept nogil
self.missing_go_to_left = False
self.n_missing = 0

# the default SplitRecord factory method simply mallocs a SplitRecord
cdef SplitRecord* _base_split_record_factory(SplitRecordFactoryEnv env) except NULL nogil:
return <SplitRecord*>malloc(sizeof(SplitRecord));
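As a hedged illustration of the subclassing hook this factory enables: the record type and extra field below are invented, only the SplitRecordFactory signature comes from this diff, and the sketch assumes the module's existing malloc and float64_t cimports.

# HYPOTHETICAL example, not part of this commit: an extended record whose first
# member is the base SplitRecord, so a pointer to it can be cast back to
# SplitRecord* safely.
cdef struct ExampleExtendedSplitRecord:
    SplitRecord base
    float64_t example_statistic   # invented extra field

cdef SplitRecord* _example_extended_split_record_factory(
    SplitRecordFactoryEnv env
) except NULL nogil:
    # allocate the larger record, hand it back through the base pointer type
    return <SplitRecord*>malloc(sizeof(ExampleExtendedSplitRecord))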

@@ -281,20 +290,6 @@ cdef class Splitter(BaseSplitter):
self.min_samples_leaf_condition = MinSamplesLeafCondition()
self.min_weight_leaf_condition = MinWeightLeafCondition()

#self.presplit_conditions.resize(
# (len(presplit_conditions) if presplit_conditions is not None else 0)
# + (2 if self.with_monotonic_cst else 1)
#)
#self.postsplit_conditions.resize(
# (len(postsplit_conditions) if postsplit_conditions is not None else 0)
# + (2 if self.with_monotonic_cst else 1)
#)

#cdef int offset = 0
#self.presplit_conditions[offset] = self.min_samples_leaf_condition.c
#self.postsplit_conditions[offset] = self.min_weight_leaf_condition.c
#offset += 1

l_pre = [self.min_samples_leaf_condition]
l_post = [self.min_weight_leaf_condition]

@@ -306,16 +301,11 @@ cdef class Splitter(BaseSplitter):
#self.postsplit_conditions[offset] = self.monotonic_constraint_condition.c
#offset += 1

#cdef int i
if presplit_conditions is not None:
l_pre += presplit_conditions
#for i in range(len(presplit_conditions)):
# self.presplit_conditions[i + offset] = presplit_conditions[i].c

if postsplit_conditions is not None:
l_post += postsplit_conditions
#for i in range(len(postsplit_conditions)):
# self.postsplit_conditions[i + offset] = postsplit_conditions[i].c

self.presplit_conditions.resize(0)
self.add_presplit_conditions(l_pre)
@@ -595,10 +585,6 @@ cdef inline intp_t node_split_best(
Returns -1 in case of failure to allocate memory (and raise MemoryError)
or 0 otherwise.
"""
#with gil:
# print("")
# print("in node_split_best")

cdef const int8_t[:] monotonic_cst = splitter.monotonic_cst
cdef bint with_monotonic_cst = splitter.with_monotonic_cst

@@ -648,19 +634,14 @@ cdef inline intp_t node_split_best(

cdef bint conditions_hold = True

# payloads for different node events
cdef NodeSortFeatureEventData sort_event_data
cdef NodeSplitEventData split_event_data

#with gil:
# print("checkpoint 1")

_init_split(&best_split, end)

partitioner.init_node_split(start, end)

#with gil:
# print("checkpoint 2")

# Sample up to max_features without replacement using a
# Fisher-Yates-based algorithm (using the local variables `f_i` and
# `f_j` to compute a permutation of the `features` array).
@@ -706,6 +687,7 @@ cdef inline intp_t node_split_best(
current_split.feature = features[f_j]
partitioner.sort_samples_and_feature_values(current_split.feature)

# notify any interested parties which feature we are currently investigating splits for
sort_event_data.feature = current_split.feature
splitter.event_broker.fire_event(NodeSplitEvent.SORT_FEATURE, &sort_event_data)
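For illustration, a listener for this event might look like the sketch below; it is hypothetical, and both the environment parameter type and the EventBroker registration API are assumptions not shown in this diff.

# HYPOTHETICAL listener sketch. fire_event ANDs listener return values together,
# so returning True keeps the aggregate result truthy.
cdef bint _example_sort_feature_listener(
    EventType event_type,
    void* listener_env,     # assumed: the broker passes each listener's stored env
    EventData event_data
) noexcept nogil:
    cdef NodeSortFeatureEventData* data = <NodeSortFeatureEventData*>event_data
    cdef intp_t feature = data.feature  # the feature currently being evaluated
    return True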

@@ -741,46 +723,28 @@
n_searches = 2 if has_missing else 1

for i in range(n_searches):
#with gil:
# print(f"search {i}")

missing_go_to_left = i == 1
criterion.missing_go_to_left = missing_go_to_left
criterion.reset()

p = start

while p < end_non_missing:
#with gil:
# print("")
# print("_node_split_best checkpoint 1")

partitioner.next_p(&p_prev, &p)

#with gil:
# print("checkpoint 1.1")
# print(f"end_non_missing = {end_non_missing}")
# print(f"p = {<int32_t>p}")

if p >= end_non_missing:
#with gil:
# print("continuing")
continue

#with gil:
# print("_node_split_best checkpoint 1.2")

current_split.pos = p

# probably want to assign this to current_split.threshold later,
# but the code is so stateful that Write Everything Twice is the
# safer move here for now
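# (computing the midpoint as a sum of halves rather than (a + b) / 2 follows
# the upstream sklearn idiom for avoiding overflow to infinity when both
# feature values are very large)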
current_threshold = (
feature_values[p_prev] / 2.0 + feature_values[p] / 2.0
)

#with gil:
# print("_node_split_best checkpoint 2")

# check pre split rejection criteria
conditions_hold = True
for condition in splitter.presplit_conditions:
if not condition.f(
@@ -791,24 +755,18 @@
conditions_hold = False
break

#with gil:
# print("_node_split_best checkpoint 3")

if not conditions_hold:
continue

# Reject if min_samples_leaf is not guaranteed
# this can probably (and should) be removed as it is generalized
# by injectable split rejection criteria
if splitter.check_presplit_conditions(&current_split, n_missing, missing_go_to_left) == 1:
continue

#with gil:
# print("_node_split_best checkpoint 4")

criterion.update(current_split.pos)

#with gil:
# print("_node_split_best checkpoint 5")

# check post split rejection criteria
conditions_hold = True
for condition in splitter.postsplit_conditions:
if not condition.f(
@@ -819,15 +777,9 @@
conditions_hold = False
break

#with gil:
# print("_node_split_best checkpoint 6")

if not conditions_hold:
continue

#with gil:
# print("_node_split_best checkpoint 7")

current_proxy_improvement = criterion.proxy_impurity_improvement()

if current_proxy_improvement > best_proxy_improvement:
@@ -859,15 +811,9 @@

best_split = current_split # copy

#with gil:
# print("_node_split_best checkpoint 8")

# Evaluate when there are missing values and all missing values go
# to the right node and non-missing values go to the left node.
if has_missing:
#with gil:
# print("has_missing = {has_missing}")

n_left, n_right = end - start - n_missing, n_missing
p = end - n_missing
missing_go_to_left = 0
@@ -888,24 +834,16 @@
current_split.pos = p
best_split = current_split

#with gil:
# print("checkpoint 9")

# Reorganize into samples[start:best_split.pos] + samples[best_split.pos:end]
if best_split.pos < end:
#with gil:
# print("checkpoint 10")

partitioner.partition_samples_final(
best_split.pos,
best_split.threshold,
best_split.feature,
best_split.n_missing
)

#with gil:
# print("checkpoint 11")

criterion.init_missing(best_split.n_missing)
criterion.missing_go_to_left = best_split.missing_go_to_left

@@ -920,37 +858,23 @@
best_split.impurity_right
)

#with gil:
# print("checkpoint 12")

shift_missing_values_to_left_if_required(&best_split, samples, end)

#with gil:
# print("checkpoint 13")

# Respect invariant for constant features: the original order of
# elements in features[:n_known_constants] must be preserved for sibling
# and child nodes
memcpy(&features[0], &constant_features[0], sizeof(intp_t) * n_known_constants)

#with gil:
# print("checkpoint 14")

# Copy newly found constant features
memcpy(&constant_features[n_known_constants],
&features[n_known_constants],
sizeof(intp_t) * n_found_constants)

#with gil:
# print("checkpoint 15")

# Return values
parent_record.n_constant_features = n_total_constants
split[0] = best_split

#with gil:
# print("returning")

return 0


