Skip to content

Commit

Permalink
honesty wip
Browse files Browse the repository at this point in the history
  • Loading branch information
SamuelCarliles3 committed Jul 5, 2024
1 parent 69fc530 commit 61dfd0f
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 16 deletions.
22 changes: 13 additions & 9 deletions sklearn/tree/_honesty.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@

# See _honesty.pyx for details.

from ._events cimport EventHandler
from ._splitter cimport Partitioner, NodeSplitEvent, NodeSortFeatureEventData, NodeSplitEventData
from ._events cimport EventData, EventHandler, EventHandlerEnv, EventType
from ._splitter cimport Partitioner, Splitter
from ._splitter cimport NodeSplitEvent, NodeSortFeatureEventData, NodeSplitEventData
from ._splitter cimport SplitConditionEnv, SplitConditionFunction, SplitConditionClosure, SplitCondition
from ._tree cimport TreeBuildEvent, TreeBuildSetActiveParentEventData, TreeBuildAddNodeEventData

Expand All @@ -21,27 +22,30 @@ cdef struct Interval:
intp_t split_idx # start of right child
float64_t split_value

cdef struct HonestEnv:
const float32_t[:, :] X
intp_t[::1] samples
float32_t[::1] feature_values
# Views groups three typed memoryviews (X, samples, feature_values) and a
# Partitioner in one extension type. HonestEnv does not embed these fields;
# it reaches a Views instance through its untyped `void* data_views` member,
# which is cast back with `<Views>` before use.
# NOTE(review): presumably this was split out of HonestEnv because a plain C
# struct cannot own memoryviews or Python object references -- confirm.
cdef class Views:
cdef:
# 2-D, read-only float32 view.
# NOTE(review): appears to be indexed as X[sample_index, feature] -- confirm.
const float32_t[:, :] X
# 1-D contiguous view of intp_t values.
# NOTE(review): presumably row indices into X -- confirm.
intp_t[::1] samples
# 1-D contiguous float32 view.
float32_t[::1] feature_values
# Assigned from the `honest_partitioner` argument in Honesty.__cinit__.
Partitioner partitioner

cdef struct HonestEnv:
void* data_views
vector[Interval] tree
Interval* active_parent
Interval active_node
intp_t active_is_left
Partitioner partitioner

cdef class Honesty:
cdef:
object splitter_event_handlers # python list of EventHandler
object split_conditions # python list of SplitCondition
object tree_event_handlers # python list of EventHandler

Views views
HonestEnv env
Partitioner partitioner

cdef struct MinSampleLeafConditionEnv:
cdef struct MinSamplesLeafConditionEnv:
intp_t min_samples
HonestEnv* honest_env

Expand Down
14 changes: 7 additions & 7 deletions sklearn/tree/_honesty.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@ cdef class Honesty:
def __cinit__(
self,
Partitioner honest_partitioner,
intp_t min_samples_leaf,
list splitter_event_handlers = None,
list split_conditions = None,
list tree_event_handlers = None,
intp_t min_samples_leaf
list tree_event_handlers = None
):
if splitter_event_handlers is None:
splitter_event_handlers = []
Expand All @@ -17,7 +17,7 @@ cdef class Honesty:
if tree_event_handlers is None:
tree_event_handlers = []

self.env.partitioner = honest_partitioner
(<Views>self.env.data_views).partitioner = honest_partitioner
self.splitter_event_handlers = [NodeSortFeatureHandler(&self.env)] + splitter_event_handlers
self.split_conditions = [HonestMinSamplesLeafCondition(min_samples_leaf, &self.env)] + split_conditions
self.tree_event_handlers = [SetActiveParentHandler(&self.env), AddNodeHandler(&self.env)] + tree_event_handlers
Expand Down Expand Up @@ -80,7 +80,7 @@ cdef bint _handle_sort_feature(

cdef HonestEnv* env = <HonestEnv*>handler_env
cdef NodeSortFeatureEventData* data = <NodeSortFeatureEventData*>event_data
cdev Interval* node = &env.active_node
cdef Interval* node = &env.active_node

node.feature = data.feature
node.split_idx = 0
Expand All @@ -106,9 +106,9 @@ cdef bint _handle_add_node(
if event_type != TreeBuildEvent.ADD_NODE:
return True

cdef HonestEnv* env = <HonestEnv*>handler_env
cdef float64_t h, feature_value
cdef intp_t i, n_left, n_missing, size = env.tree.size()
cdef HonestEnv* env = <HonestEnv*>handler_env
cdef TreeBuildAddNodeEventData* data = <TreeBuildAddNodeEventData*>event_data
cdef Interval *interval, *parent

Expand Down Expand Up @@ -146,7 +146,7 @@ cdef bint _handle_add_node(
i = interval.start_idx
feature_value = env.X[env.samples[i], interval.feature]

while !isnan(feature_value) && feature_value < interval.split_value && i < interval.start_idx + interval.n:
while (not isnan(feature_value)) and feature_value < interval.split_value and i < interval.start_idx + interval.n:
n_left += 1
i += 1
feature_value = env.X[env.samples[i], interval.feature]
Expand Down Expand Up @@ -190,7 +190,7 @@ cdef bint _honest_min_sample_leaf_condition(

# split_pos is a position in the structure set, so it is of no use here;
# instead, scan forward in the honest set by split_value to find the
# corresponding split index.
while node.split_idx < node.start_idx + node.n && env.X[node.split_idx, node.feature] <= split_value:
while node.split_idx < node.start_idx + node.n and env.X[node.split_idx, node.feature] <= split_value:
node.split_idx += 1

if missing_go_to_left:
Expand Down

0 comments on commit 61dfd0f

Please sign in to comment.