Skip to content

Commit

Permalink
Migrate n_constant_features within SplitRecord
Browse files: browse the repository at this point in the history
Signed-off-by: Adam Li <adam2392@gmail.com>
  • Loading branch information
adam2392 committed Mar 10, 2024
1 parent e1224a4 commit 5ccd00f
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 31 deletions.
2 changes: 1 addition & 1 deletion sklearn/tree/_splitter.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ cdef struct SplitRecord:
float64_t upper_bound # Upper bound on value of both children for monotonicity
unsigned char missing_go_to_left # Controls if missing values go to the left node.
intp_t n_missing # Number of missing values for the feature being split on
intp_t n_constant_features # Number of constant features in the split

cdef class BaseSplitter:
"""Abstract interface for splitter."""
Expand Down Expand Up @@ -90,7 +91,6 @@ cdef class BaseSplitter:
self,
float64_t impurity, # Impurity of the node
SplitRecord* split,
intp_t* n_constant_features,
float64_t lower_bound,
float64_t upper_bound,
) except -1 nogil
Expand Down
24 changes: 5 additions & 19 deletions sklearn/tree/_splitter.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ cdef inline void _init_split(SplitRecord* self, intp_t start_pos) noexcept nogil
self.improvement = -INFINITY
self.missing_go_to_left = False
self.n_missing = 0
self.n_constant_features = 0

cdef class BaseSplitter:
"""This is an abstract interface for splitters.
Expand Down Expand Up @@ -100,7 +101,6 @@ cdef class BaseSplitter:
self,
float64_t impurity,
SplitRecord* split,
intp_t* n_constant_features,
float64_t lower_bound,
float64_t upper_bound
) except -1 nogil:
Expand All @@ -118,9 +118,6 @@ cdef class BaseSplitter:
split : SplitRecord pointer
A pointer to a memory-allocated SplitRecord object which will be filled with the
split chosen.
n_constant_features : intp_t pointer
A pointer to a memory-allocated intp_t object which will be filled with the
number of constant features. Optional to use.
lower_bound : float64_t
The lower bound of the monotonic constraint if used.
upper_bound : float64_t
Expand Down Expand Up @@ -322,7 +319,6 @@ cdef class Splitter(BaseSplitter):
self,
float64_t impurity,
SplitRecord* split,
intp_t* n_constant_features,
float64_t lower_bound,
float64_t upper_bound,
) except -1 nogil:
Expand Down Expand Up @@ -444,7 +440,6 @@ cdef inline intp_t node_split_best(
Criterion criterion,
float64_t impurity,
SplitRecord* split,
intp_t* n_constant_features,
bint with_monotonic_cst,
const cnp.int8_t[:] monotonic_cst,
float64_t lower_bound,
Expand Down Expand Up @@ -490,7 +485,7 @@ cdef inline intp_t node_split_best(
cdef intp_t n_found_constants = 0
# Number of features known to be constant and drawn without replacement
cdef intp_t n_drawn_constants = 0
cdef intp_t n_known_constants = n_constant_features[0]
cdef intp_t n_known_constants = split.n_constant_features
# n_total_constants = n_known_constants + n_found_constants
cdef intp_t n_total_constants = n_known_constants

Expand Down Expand Up @@ -711,7 +706,7 @@ cdef inline intp_t node_split_best(

# Return values
split[0] = best_split
n_constant_features[0] = n_total_constants
split.n_constant_features = n_total_constants
return 0


Expand Down Expand Up @@ -834,7 +829,6 @@ cdef inline int node_split_random(
Criterion criterion,
float64_t impurity,
SplitRecord* split,
intp_t* n_constant_features,
bint with_monotonic_cst,
const cnp.int8_t[:] monotonic_cst,
float64_t lower_bound,
Expand Down Expand Up @@ -866,7 +860,7 @@ cdef inline int node_split_random(
cdef intp_t n_found_constants = 0
# Number of features known to be constant and drawn without replacement
cdef intp_t n_drawn_constants = 0
cdef intp_t n_known_constants = n_constant_features[0]
cdef intp_t n_known_constants = split.n_constant_features
# n_total_constants = n_known_constants + n_found_constants
cdef intp_t n_total_constants = n_known_constants
cdef intp_t n_visited_features = 0
Expand Down Expand Up @@ -1021,7 +1015,7 @@ cdef inline int node_split_random(

# Return values
split[0] = best_split
n_constant_features[0] = n_total_constants
split.n_constant_features = n_total_constants
return 0


Expand Down Expand Up @@ -1679,7 +1673,6 @@ cdef class BestSplitter(Splitter):
self,
float64_t impurity,
SplitRecord* split,
intp_t* n_constant_features,
float64_t lower_bound,
float64_t upper_bound
) except -1 nogil:
Expand All @@ -1689,7 +1682,6 @@ cdef class BestSplitter(Splitter):
self.criterion,
impurity,
split,
n_constant_features,
self.with_monotonic_cst,
self.monotonic_cst,
lower_bound,
Expand All @@ -1715,7 +1707,6 @@ cdef class BestSparseSplitter(Splitter):
self,
float64_t impurity,
SplitRecord* split,
intp_t* n_constant_features,
float64_t lower_bound,
float64_t upper_bound
) except -1 nogil:
Expand All @@ -1725,7 +1716,6 @@ cdef class BestSparseSplitter(Splitter):
self.criterion,
impurity,
split,
n_constant_features,
self.with_monotonic_cst,
self.monotonic_cst,
lower_bound,
Expand All @@ -1751,7 +1741,6 @@ cdef class RandomSplitter(Splitter):
self,
float64_t impurity,
SplitRecord* split,
intp_t* n_constant_features,
float64_t lower_bound,
float64_t upper_bound
) except -1 nogil:
Expand All @@ -1761,7 +1750,6 @@ cdef class RandomSplitter(Splitter):
self.criterion,
impurity,
split,
n_constant_features,
self.with_monotonic_cst,
self.monotonic_cst,
lower_bound,
Expand All @@ -1786,7 +1774,6 @@ cdef class RandomSparseSplitter(Splitter):
self,
float64_t impurity,
SplitRecord* split,
intp_t* n_constant_features,
float64_t lower_bound,
float64_t upper_bound
) except -1 nogil:
Expand All @@ -1796,7 +1783,6 @@ cdef class RandomSparseSplitter(Splitter):
self.criterion,
impurity,
split,
n_constant_features,
self.with_monotonic_cst,
self.monotonic_cst,
lower_bound,
Expand Down
21 changes: 10 additions & 11 deletions sklearn/tree/_tree.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ cdef class TreeBuilder:

return X, y, sample_weight


# Depth first builder ---------------------------------------------------------
# A record on the stack for depth-first tree growing
cdef struct StackRecord:
Expand All @@ -166,6 +167,7 @@ cdef struct StackRecord:
float64_t lower_bound
float64_t upper_bound


cdef class DepthFirstTreeBuilder(TreeBuilder):
"""Build a decision tree in depth-first fashion."""

Expand Down Expand Up @@ -328,7 +330,6 @@ cdef class DepthFirstTreeBuilder(TreeBuilder):
cdef float64_t lower_bound
cdef float64_t upper_bound
cdef float64_t middle_value
cdef intp_t n_constant_features
cdef bint is_leaf
cdef intp_t max_depth_seen = -1 if first else tree.max_depth

Expand Down Expand Up @@ -379,7 +380,7 @@ cdef class DepthFirstTreeBuilder(TreeBuilder):
parent = stack_record.parent
is_left = stack_record.is_left
impurity = stack_record.impurity
n_constant_features = stack_record.n_constant_features
split_ptr.n_constant_features = stack_record.n_constant_features
lower_bound = stack_record.lower_bound
upper_bound = stack_record.upper_bound

Expand All @@ -398,7 +399,6 @@ cdef class DepthFirstTreeBuilder(TreeBuilder):
splitter.node_split(
impurity,
split_ptr,
&n_constant_features,
lower_bound,
upper_bound
)
Expand Down Expand Up @@ -470,7 +470,7 @@ cdef class DepthFirstTreeBuilder(TreeBuilder):
"parent": node_id,
"is_left": 0,
"impurity": split.impurity_right,
"n_constant_features": n_constant_features,
"n_constant_features": split.n_constant_features,
"lower_bound": right_child_min,
"upper_bound": right_child_max,
})
Expand All @@ -483,7 +483,7 @@ cdef class DepthFirstTreeBuilder(TreeBuilder):
"parent": node_id,
"is_left": 1,
"impurity": split.impurity_left,
"n_constant_features": n_constant_features,
"n_constant_features": split.n_constant_features,
"lower_bound": left_child_min,
"upper_bound": left_child_max,
})
Expand All @@ -504,7 +504,7 @@ cdef class DepthFirstTreeBuilder(TreeBuilder):
parent = stack_record.parent
is_left = stack_record.is_left
impurity = stack_record.impurity
n_constant_features = stack_record.n_constant_features
split_ptr.n_constant_features = stack_record.n_constant_features
lower_bound = stack_record.lower_bound
upper_bound = stack_record.upper_bound

Expand All @@ -527,7 +527,6 @@ cdef class DepthFirstTreeBuilder(TreeBuilder):
splitter.node_split(
impurity,
split_ptr,
&n_constant_features,
lower_bound,
upper_bound
)
Expand Down Expand Up @@ -598,7 +597,7 @@ cdef class DepthFirstTreeBuilder(TreeBuilder):
"parent": node_id,
"is_left": 0,
"impurity": split.impurity_right,
"n_constant_features": n_constant_features,
"n_constant_features": split.n_constant_features,
"lower_bound": right_child_min,
"upper_bound": right_child_max,
})
Expand All @@ -611,7 +610,7 @@ cdef class DepthFirstTreeBuilder(TreeBuilder):
"parent": node_id,
"is_left": 1,
"impurity": split.impurity_left,
"n_constant_features": n_constant_features,
"n_constant_features": split.n_constant_features,
"lower_bound": left_child_min,
"upper_bound": left_child_max,
})
Expand Down Expand Up @@ -901,11 +900,12 @@ cdef class BestFirstTreeBuilder(TreeBuilder):

cdef intp_t node_id
cdef intp_t n_node_samples
cdef intp_t n_constant_features = 0
cdef float64_t min_impurity_decrease = self.min_impurity_decrease
cdef float64_t weighted_n_node_samples
cdef bint is_leaf

# best-first building does not carry a known-constant-feature count between nodes, so each split starts from zero
split_ptr.n_constant_features = 0
splitter.node_reset(start, end, &weighted_n_node_samples)

if is_first:
Expand All @@ -923,7 +923,6 @@ cdef class BestFirstTreeBuilder(TreeBuilder):
splitter.node_split(
impurity,
split_ptr,
&n_constant_features,
lower_bound,
upper_bound
)
Expand Down

0 comments on commit 5ccd00f

Please sign in to comment.