Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Scarliles/splitter injection #61

Open
wants to merge 20 commits into
base: submodulev3
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions sklearn/tree/_splitter.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,60 @@
# Jacob Schreiber <jmschreiber91@gmail.com>
# Adam Li <adam2392@gmail.com>
# Jong Shin <jshinm@gmail.com>
# Samuel Carliles <scarlil1@jhu.edu>
#
# License: BSD 3 clause

# See _splitter.pyx for details.
cimport numpy as cnp

from libcpp.vector cimport vector
from libc.stdlib cimport malloc

from ..utils._typedefs cimport float32_t, float64_t, intp_t, int32_t
from ._utils cimport UINT32_t
from ._criterion cimport BaseCriterion, Criterion


# DESIGN NOTE: NICE IDEAS THAT DON'T APPEAR TO BE POSSIBLE IN CYTHON
# - accessing elements of a memoryview of Cython extension types in a nogil block/function
# - storing Cython extension types in C++ vectors
#
# This is despite the fact that we can access scalar extension-type attributes in such a
# context (as, for instance, node_split_best does with Criterion and the partitioner),
# and that we can access the elements of a memoryview of primitive types in such a context.
#
# SO WHERE DOES THAT LEAVE US
# - we can transform the conditions into C++ vectors of plain structs
#   (function pointer + opaque parameter pointer), and apart from some
#   minor casting irritations everything else works OK
# Opaque pointer to condition-specific parameters. A condition function casts
# it back to its own parameter struct, or ignores it when it is NULL.
ctypedef void* SplitConditionParameters
# Signature shared by every split-condition callback. The callback receives
# the splitter, the candidate split, the missing-value information, the
# value bounds and its own opaque parameters; it returns False to reject the
# candidate split and True to accept it. Callbacks run without the GIL.
ctypedef bint (*SplitConditionFunction)(
    Splitter splitter,
    SplitRecord* current_split,
    intp_t n_missing,
    bint missing_go_to_left,
    float64_t lower_bound,
    float64_t upper_bound,
    SplitConditionParameters split_condition_parameters
) noexcept nogil

# Plain C pairing of a condition callback with the parameter pointer that is
# handed back to it; this is the form stored in the splitter's C++ vectors.
cdef struct SplitConditionTuple:
    SplitConditionFunction f  # callback evaluated for a candidate split
    SplitConditionParameters p  # forwarded to f as its last argument (may be NULL)

# Base extension type for split conditions. It only carries the C-level
# (function, parameters) tuple, which the splitter copies by value.
cdef class SplitCondition:
    cdef SplitConditionTuple t

# Declaration only; the matching __cinit__ lives in _splitter.pyx.
cdef class MinSamplesLeafCondition(SplitCondition):
    pass

# Declaration only; the matching __cinit__ lives in _splitter.pyx.
cdef class MinWeightLeafCondition(SplitCondition):
    pass

# Declaration only; the matching __cinit__ lives in _splitter.pyx.
cdef class MonotonicConstraintCondition(SplitCondition):
    pass


cdef struct SplitRecord:
# Data to track sample split
intp_t feature # Which feature to split on.
Expand Down Expand Up @@ -112,6 +153,13 @@ cdef class Splitter(BaseSplitter):
cdef const cnp.int8_t[:] monotonic_cst
cdef bint with_monotonic_cst

cdef SplitCondition min_samples_leaf_condition
cdef SplitCondition min_weight_leaf_condition
cdef SplitCondition monotonic_constraint_condition

cdef vector[SplitConditionTuple] presplit_conditions
cdef vector[SplitConditionTuple] postsplit_conditions

cdef int init(
self,
object X,
Expand Down
269 changes: 233 additions & 36 deletions sklearn/tree/_splitter.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from cython cimport final
from libc.math cimport isnan
from libc.stdlib cimport qsort
from libc.stdlib cimport qsort, free
from libc.string cimport memcpy
cimport numpy as cnp

Expand All @@ -43,6 +43,155 @@ cdef float32_t FEATURE_THRESHOLD = 1e-7
# in SparsePartitioner
cdef float32_t EXTRACT_NNZ_SWITCH = 0.1


cdef bint min_sample_leaf_condition(
    Splitter splitter,
    SplitRecord* current_split,
    intp_t n_missing,
    bint missing_go_to_left,
    float64_t lower_bound,
    float64_t upper_bound,
    SplitConditionParameters split_condition_parameters
) noexcept nogil:
    """Check that both children of a candidate split hold enough samples.

    The child sizes are derived from ``current_split.pos`` relative to
    ``splitter.start`` and ``splitter.end``; the ``n_missing`` samples are
    counted on the left child when ``missing_go_to_left`` is set and on the
    right child otherwise.

    ``lower_bound``, ``upper_bound`` and ``split_condition_parameters`` are
    unused; the threshold is read from ``splitter.min_samples_leaf``.

    Returns
    -------
    bint
        False if either child would have fewer than
        ``splitter.min_samples_leaf`` samples, True otherwise.
    """
    # Sizes of the two children before the missing samples are assigned.
    cdef intp_t left_count = current_split.pos - splitter.start
    cdef intp_t right_count = (splitter.end - n_missing) - current_split.pos

    # Missing-valued samples all go to one side.
    if missing_go_to_left:
        left_count += n_missing
    else:
        right_count += n_missing

    return (
        left_count >= splitter.min_samples_leaf and
        right_count >= splitter.min_samples_leaf
    )

cdef class MinSamplesLeafCondition(SplitCondition):
    """Split condition wrapping ``min_sample_leaf_condition``."""

    def __cinit__(self):
        # No extra parameters: the callback reads min_samples_leaf from the
        # splitter it is given.
        self.t.p = NULL
        self.t.f = min_sample_leaf_condition

cdef bint min_weight_leaf_condition(
    Splitter splitter,
    SplitRecord* current_split,
    intp_t n_missing,
    bint missing_go_to_left,
    float64_t lower_bound,
    float64_t upper_bound,
    SplitConditionParameters split_condition_parameters
) noexcept nogil:
    """Check that both children of a candidate split carry enough weight.

    Only ``splitter`` is used: the weights are read from
    ``splitter.criterion.weighted_n_left`` / ``weighted_n_right`` and compared
    against ``splitter.min_weight_leaf``.

    Returns
    -------
    bint
        False if either child's weight is below ``splitter.min_weight_leaf``,
        True otherwise.
    """
    cdef float64_t threshold = splitter.min_weight_leaf
    cdef bint left_too_light = splitter.criterion.weighted_n_left < threshold
    cdef bint right_too_light = splitter.criterion.weighted_n_right < threshold

    return not (left_too_light or right_too_light)

cdef class MinWeightLeafCondition(SplitCondition):
    """Split condition wrapping ``min_weight_leaf_condition``."""

    def __cinit__(self):
        # No extra parameters: the callback reads min_weight_leaf from the
        # splitter it is given.
        self.t.p = NULL
        self.t.f = min_weight_leaf_condition

cdef bint monotonic_constraint_condition(
    Splitter splitter,
    SplitRecord* current_split,
    intp_t n_missing,
    bint missing_go_to_left,
    float64_t lower_bound,
    float64_t upper_bound,
    SplitConditionParameters split_condition_parameters
) noexcept nogil:
    """Check the monotonicity constraint of the candidate split's feature.

    The split is accepted outright when the splitter has no monotonic
    constraints or when the constraint for ``current_split.feature`` is 0.
    Otherwise the decision is delegated to
    ``splitter.criterion.check_monotonicity`` with the given bounds.

    ``n_missing``, ``missing_go_to_left`` and ``split_condition_parameters``
    are unused.

    Returns
    -------
    bint
        False if the constraint is violated, True otherwise.
    """
    # Constraints disabled: do not touch the monotonic_cst memoryview at all.
    if not splitter.with_monotonic_cst:
        return True

    # Feature is unconstrained.
    if splitter.monotonic_cst[current_split.feature] == 0:
        return True

    return splitter.criterion.check_monotonicity(
        splitter.monotonic_cst[current_split.feature],
        lower_bound,
        upper_bound,
    )

cdef class MonotonicConstraintCondition(SplitCondition):
    """Split condition wrapping ``monotonic_constraint_condition``."""

    def __cinit__(self):
        # No extra parameters: the callback reads the constraints from the
        # splitter it is given.
        self.t.p = NULL
        self.t.f = monotonic_constraint_condition

# Parameters of has_data_condition, passed to it through the opaque
# SplitConditionParameters pointer.
cdef struct HasDataParameters:
    int min_samples  # threshold compared against splitter.n_samples

cdef bint has_data_condition(
    Splitter splitter,
    SplitRecord* current_split,
    intp_t n_missing,
    bint missing_go_to_left,
    float64_t lower_bound,
    float64_t upper_bound,
    SplitConditionParameters split_condition_parameters
) noexcept nogil:
    """Accept a split only if the splitter has at least ``min_samples`` samples.

    ``split_condition_parameters`` must point to a ``HasDataParameters``
    struct; only ``splitter.n_samples`` and that struct's ``min_samples``
    are read, every other argument is unused.

    Returns
    -------
    bint
        True if ``splitter.n_samples >= min_samples``, False otherwise.
    """
    # Recover the typed parameters from the opaque pointer and compare in one step.
    return splitter.n_samples >= (
        <HasDataParameters*>split_condition_parameters
    ).min_samples

cdef class HasDataCondition(SplitCondition):
    """Split condition wrapping ``has_data_condition``.

    Owns a heap-allocated ``HasDataParameters`` struct holding
    ``min_samples``; the struct is released when this object is deallocated.

    NOTE(review): ``self.t`` is copied by value by the splitter, so the copy
    keeps the raw parameter pointer. This object presumably has to outlive
    any splitter using it -- confirm that callers keep a reference.

    Parameters
    ----------
    min_samples : int
        Threshold stored in the parameter struct.

    Raises
    ------
    MemoryError
        If the parameter struct cannot be allocated.
    """

    def __cinit__(self, int min_samples):
        cdef HasDataParameters* params = <HasDataParameters*>malloc(
            sizeof(HasDataParameters)
        )
        # malloc may fail; never dereference a NULL pointer.
        if params is NULL:
            raise MemoryError()

        params.min_samples = min_samples
        self.t.f = has_data_condition
        self.t.p = params

    def __dealloc__(self):
        # The base-class __dealloc__ is chained automatically by Cython, so
        # it must not be invoked explicitly here.
        if self.t.p is not NULL:
            free(self.t.p)
            # Guard against a dangling pointer / double free.
            self.t.p = NULL

# Parameters of alpha_regularity_condition, passed to it through the opaque
# SplitConditionParameters pointer.
cdef struct AlphaRegularityParameters:
    float64_t alpha  # stored by AlphaRegularityCondition; not yet read by the callback

cdef bint alpha_regularity_condition(
    Splitter splitter,
    SplitRecord* current_split,
    intp_t n_missing,
    bint missing_go_to_left,
    float64_t lower_bound,
    float64_t upper_bound,
    SplitConditionParameters split_condition_parameters
) noexcept nogil:
    """Placeholder condition: currently accepts every candidate split.

    None of the arguments are used yet. ``split_condition_parameters`` is
    expected to point to an ``AlphaRegularityParameters`` struct.

    Returns
    -------
    bint
        Always True.
    """
    # TODO: implement the actual check using
    # (<AlphaRegularityParameters*>split_condition_parameters).alpha.
    return True

cdef class AlphaRegularityCondition(SplitCondition):
    """Split condition wrapping ``alpha_regularity_condition``.

    Owns a heap-allocated ``AlphaRegularityParameters`` struct holding
    ``alpha``; the struct is released when this object is deallocated.

    NOTE(review): ``self.t`` is copied by value by the splitter, so the copy
    keeps the raw parameter pointer. This object presumably has to outlive
    any splitter using it -- confirm that callers keep a reference.

    Parameters
    ----------
    alpha : float64_t
        Value stored in the parameter struct.

    Raises
    ------
    MemoryError
        If the parameter struct cannot be allocated.
    """

    def __cinit__(self, float64_t alpha):
        cdef AlphaRegularityParameters* params = <AlphaRegularityParameters*>malloc(
            sizeof(AlphaRegularityParameters)
        )
        # malloc may fail; never dereference a NULL pointer.
        if params is NULL:
            raise MemoryError()

        params.alpha = alpha
        self.t.f = alpha_regularity_condition
        self.t.p = params

    def __dealloc__(self):
        # The base-class __dealloc__ is chained automatically by Cython, so
        # it must not be invoked explicitly here.
        if self.t.p is not NULL:
            free(self.t.p)
            # Guard against a dangling pointer / double free.
            self.t.p = NULL


from ._tree cimport Tree
# Example Tree subclass that builds a Splitter with one extra pre-split and
# one extra post-split condition.
# NOTE(review): this looks like demonstration scaffolding rather than
# production code -- confirm whether it should be merged.
cdef class FooTree(Tree):
    cdef Splitter splitter

    def __init__(self):
        # NOTE(review): only the two condition lists are passed here, while the
        # visible part of Splitter.__cinit__ also lists min_weight_leaf,
        # random_state and monotonic_cst without defaults -- confirm that this
        # call can succeed.
        # NOTE(review): the condition objects are referenced only by these
        # temporary lists, yet each frees its parameter struct on
        # deallocation -- verify the splitter is not left with dangling
        # pointers.
        self.splitter = Splitter(
            presplit_conditions = [HasDataCondition(10)],
            postsplit_conditions = [AlphaRegularityCondition(0.1)],
        )


cdef inline void _init_split(SplitRecord* self, intp_t start_pos) noexcept nogil:
self.impurity_left = INFINITY
self.impurity_right = INFINITY
Expand Down Expand Up @@ -155,6 +304,8 @@ cdef class Splitter(BaseSplitter):
float64_t min_weight_leaf,
object random_state,
const cnp.int8_t[:] monotonic_cst,
SplitCondition[:] presplit_conditions = None,
SplitCondition[:] postsplit_conditions = None,
*argv
):
"""
Expand Down Expand Up @@ -195,6 +346,25 @@ cdef class Splitter(BaseSplitter):
self.monotonic_cst = monotonic_cst
self.with_monotonic_cst = monotonic_cst is not None

self.min_samples_leaf_condition = MinSamplesLeafCondition()
self.min_weight_leaf_condition = MinWeightLeafCondition()

self.presplit_conditions.push_back((<SplitCondition>self.min_samples_leaf_condition).t)
if presplit_conditions is not None:
for condition in presplit_conditions:
self.presplit_conditions.push_back((<SplitCondition>condition).t)

self.postsplit_conditions.push_back((<SplitCondition>self.min_weight_leaf_condition).t)
if postsplit_conditions is not None:
for condition in postsplit_conditions:
self.postsplit_conditions.push_back((<SplitCondition>condition).t)

if(self.with_monotonic_cst):
self.monotonic_constraint_condition = MonotonicConstraintCondition()
self.presplit_conditions.push_back((<SplitCondition>self.monotonic_constraint_condition).t)
self.postsplit_conditions.push_back((<SplitCondition>self.monotonic_constraint_condition).t)


def __reduce__(self):
return (type(self), (self.criterion,
self.max_features,
Expand Down Expand Up @@ -487,6 +657,8 @@ cdef inline intp_t node_split_best(
# n_total_constants = n_known_constants + n_found_constants
cdef intp_t n_total_constants = n_known_constants

cdef bint conditions_hold = True

_init_split(&best_split, end)

partitioner.init_node_split(start, end)
Expand Down Expand Up @@ -581,46 +753,71 @@ cdef inline intp_t node_split_best(

current_split.pos = p

# Reject if monotonicity constraints are not satisfied
if (
with_monotonic_cst and
monotonic_cst[current_split.feature] != 0 and
not criterion.check_monotonicity(
monotonic_cst[current_split.feature],
lower_bound,
upper_bound,
)
):
continue

# Reject if min_samples_leaf is not guaranteed
if missing_go_to_left:
n_left = current_split.pos - splitter.start + n_missing
n_right = end_non_missing - current_split.pos
else:
n_left = current_split.pos - splitter.start
n_right = end_non_missing - current_split.pos + n_missing
if splitter.check_presplit_conditions(&current_split, n_missing, missing_go_to_left) == 1:
# # Reject if monotonicity constraints are not satisfied
# if (
# with_monotonic_cst and
# monotonic_cst[current_split.feature] != 0 and
# not criterion.check_monotonicity(
# monotonic_cst[current_split.feature],
# lower_bound,
# upper_bound,
# )
# ):
# continue

# # Reject if min_samples_leaf is not guaranteed
# if missing_go_to_left:
# n_left = current_split.pos - splitter.start + n_missing
# n_right = end_non_missing - current_split.pos
# else:
# n_left = current_split.pos - splitter.start
# n_right = end_non_missing - current_split.pos + n_missing

conditions_hold = True
for condition in splitter.presplit_conditions:
if not condition.f(
splitter, &current_split, n_missing, missing_go_to_left,
lower_bound, upper_bound, condition.p
):
conditions_hold = False
break

if not conditions_hold:
continue

# if splitter.check_presplit_conditions(&current_split, n_missing, missing_go_to_left) == 1:
# continue

criterion.update(current_split.pos)

# Reject if monotonicity constraints are not satisfied
if (
with_monotonic_cst and
monotonic_cst[current_split.feature] != 0 and
not criterion.check_monotonicity(
monotonic_cst[current_split.feature],
lower_bound,
upper_bound,
)
):
continue

# Reject if min_weight_leaf is not satisfied
if splitter.check_postsplit_conditions() == 1:
# # Reject if monotonicity constraints are not satisfied
# if (
# with_monotonic_cst and
# monotonic_cst[current_split.feature] != 0 and
# not criterion.check_monotonicity(
# monotonic_cst[current_split.feature],
# lower_bound,
# upper_bound,
# )
# ):
# continue

conditions_hold = True
for condition in splitter.postsplit_conditions:
if not condition.f(
splitter, &current_split, n_missing, missing_go_to_left,
lower_bound, upper_bound, condition.p
):
conditions_hold = False
break

if not conditions_hold:
continue


# # Reject if min_weight_leaf is not satisfied
# if splitter.check_postsplit_conditions() == 1:
# continue

current_proxy_improvement = criterion.proxy_impurity_improvement()

if current_proxy_improvement > best_proxy_improvement:
Expand Down
Loading