Adding fix to honest forest feature importances (#156)
* Add a fix to honest forest feature importances so that `est.feature_importances_` can be called without error

---------

Signed-off-by: Adam Li <adam2392@gmail.com>
adam2392 authored Oct 26, 2023
1 parent cedcb1e commit a81c597
Showing 9 changed files with 1,414 additions and 804 deletions.
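
For context, a minimal sketch of what this fix enables (the dataset, parameters, and top-level import path here are illustrative assumptions, not taken from the commit):

    from sklearn.datasets import make_classification
    from sktree import HonestForestClassifier

    # Illustrative toy data; any classification task works here.
    X, y = make_classification(n_samples=200, n_features=5, random_state=0)

    est = HonestForestClassifier(n_estimators=10, random_state=0)
    est.fit(X, y)

    # Before this commit, accessing this property raised an error; it now
    # returns the mean impurity-based importances of the fitted trees.
    print(est.feature_importances_)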
1 change: 1 addition & 0 deletions doc/whats_new/v0.3.rst
@@ -16,6 +16,7 @@ Changelog
- |Fix| Fixes a bug where covariate indices were not shuffled by default when running FeatureImportanceForestClassifier and FeatureImportanceForestRegressor test methods, by `Sambit Panda`_ (:pr:`140`)
- |Enhancement| Add multi-view splitter for axis-aligned decision trees, by `Adam Li`_ (:pr:`129`)
- |Enhancement| Add stratified sampling option to ``FeatureImportance*`` via the ``stratify`` keyword argument, by `Yuxin Bai`_ (:pr:`143`)
+- |Fix| Fixed usage of ``feature_importances_`` property in ``HonestForestClassifier``, by `Adam Li`_ (:pr:`156`)

Code and Documentation Contributors
-----------------------------------
2,134 changes: 1,340 additions & 794 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -130,7 +130,7 @@ optional = true
[tool.poetry.group.style.dependencies]
poethepoet = "^0.16.0"
mypy = "^0.971"
-black = {extras = ["jupyter"], version = "^22.12.0"}
+black = {extras = ["jupyter"], version = "^23.10.1"}
isort = "^5.10.1"
flake8 = "^5.0.4"
bandit = "^1.7.4"
5 changes: 0 additions & 5 deletions sktree/ensemble/_honest_forest.py
@@ -491,11 +491,6 @@ def honest_indices_(self):
        check_is_fitted(self)
        return [tree.honest_indices_ for tree in self.estimators_]

-    @property
-    def feature_importances_(self):
-        """The feature importances."""
-        return self.estimator_.feature_importances_
-
    def _more_tags(self):
        return {"multioutput": False}

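Why deleting this override is the fix: the removed property delegated to `self.estimator_`, the template estimator, which presumably is never fitted, so the attribute access errored out. With the override gone, `HonestForestClassifier` falls back to the `feature_importances_` property inherited from scikit-learn's `BaseForest`, which averages importances across the fitted trees. A simplified sketch of that inherited behavior (the real property parallelizes the loop; the function name here is illustrative):

    import numpy as np

    def forest_feature_importances(fitted_trees):
        # Average each fitted tree's normalized impurity-based importances,
        # skipping degenerate single-node trees, then renormalize.
        per_tree = [
            tree.feature_importances_
            for tree in fitted_trees
            if tree.tree_.node_count > 1
        ]
        importances = np.mean(per_tree, axis=0)
        return importances / np.sum(importances)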
6 changes: 5 additions & 1 deletion sktree/ensemble/_unsupervised_forest.py
@@ -152,7 +152,11 @@ def fit(self, X, y=None, sample_weight=None):
        # that case. However, for joblib 0.12+ we respect any
        # parallel_backend contexts set at a higher level,
        # since correctness does not rely on using threads.
-        trees = Parallel(n_jobs=self.n_jobs, verbose=self.verbose, prefer="threads",)(
+        trees = Parallel(
+            n_jobs=self.n_jobs,
+            verbose=self.verbose,
+            prefer="threads",
+        )(
            delayed(_parallel_build_trees)(
                t,
                self.bootstrap,
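(The `Parallel(...)` expansions in this file and in `sktree/tree/_marginalize.py` below look like mechanical reformatting from the `black` version bump in `pyproject.toml`, with no behavioral change.)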
60 changes: 60 additions & 0 deletions sktree/tests/test_honest_forest.py
@@ -1,7 +1,9 @@
import numpy as np
import pytest
+from numpy.testing import assert_array_almost_equal
from sklearn import datasets
from sklearn.metrics import accuracy_score, r2_score
+from sklearn.utils import check_random_state
from sklearn.utils.estimator_checks import parametrize_with_checks

from sktree._lib.sklearn.tree import DecisionTreeClassifier
@@ -18,6 +20,17 @@
iris.data = iris.data[perm]
iris.target = iris.target[perm]

+# Larger classification sample used for testing feature importances
+X_large, y_large = datasets.make_classification(
+    n_samples=500,
+    n_features=10,
+    n_informative=3,
+    n_redundant=0,
+    n_repeated=0,
+    shuffle=False,
+    random_state=0,
+)


def test_toy_accuracy():
    clf = HonestForestClassifier(n_estimators=10)
@@ -192,3 +205,50 @@ def test_sklearn_compatible_estimator(estimator, check):
    ]:
        pytest.skip()
    check(estimator)


+@pytest.mark.parametrize("dtype", (np.float64, np.float32))
+@pytest.mark.parametrize("criterion", ["gini", "log_loss"])
+def test_importances(dtype, criterion):
+    """Ported from sklearn unit-test.
+    Used to ensure that honest forest feature importances are consistent with sklearn's.
+    """
+    tolerance = 0.01
+
+    # cast as dtype
+    X = X_large.astype(dtype, copy=False)
+    y = y_large.astype(dtype, copy=False)
+
+    ForestEstimator = HonestForestClassifier
+
+    est = ForestEstimator(n_estimators=10, criterion=criterion, random_state=0)
+    est.fit(X, y)
+
+    importances = est.feature_importances_
+
+    # The forest estimator can detect that only the first 3 features of the
+    # dataset are informative:
+    n_important = np.sum(importances > 0.1)
+    assert importances.shape[0] == 10
+    assert n_important == 3
+    assert np.all(importances[:3] > 0.1)
+
+    # Check with parallel
+    importances = est.feature_importances_
+    est.set_params(n_jobs=2)
+    importances_parallel = est.feature_importances_
+    assert_array_almost_equal(importances, importances_parallel)
+
+    # Check with sample weights
+    sample_weight = check_random_state(0).randint(1, 10, len(X))
+    est = ForestEstimator(n_estimators=10, random_state=0, criterion=criterion)
+    est.fit(X, y, sample_weight=sample_weight)
+    importances = est.feature_importances_
+    assert np.all(importances >= 0.0)
+
+    for scale in [0.5, 100]:
+        est = ForestEstimator(n_estimators=10, random_state=0, criterion=criterion)
+        est.fit(X, y, sample_weight=scale * sample_weight)
+        importances_bis = est.feature_importances_
+        assert np.abs(importances - importances_bis).mean() < tolerance
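
To run just the new test locally (assuming a built development install of sktree):

    pytest sktree/tests/test_honest_forest.py -k test_importances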
2 changes: 1 addition & 1 deletion sktree/tree/_honest_tree.py
@@ -151,7 +151,7 @@ class HonestTreeClassifier(MetaEstimatorMixin, ClassifierMixin, BaseDecisionTree
        Read more in the :ref:`User Guide <monotonic_cst_gbdt>`.

    tree_estimator : object, default=None
-        Instatiated tree of type BaseDecisionTree from sktree.
+        Instantiated tree of type BaseDecisionTree from sktree.
        If None, then DecisionTreeClassifier with default parameters will
        be used. Note that one MUST use trees imported from the `sktree.tree`
        API namespace rather than from `sklearn.tree`.
2 changes: 1 addition & 1 deletion sktree/tree/_marginal.pyx
@@ -171,7 +171,7 @@ cdef inline cnp.ndarray _apply_dense_marginal(
        n_left_samples = tree.nodes[node.left_child].n_node_samples
        n_right_samples = tree.nodes[node.right_child].n_node_samples

-        # compute the probabilies for going left and right
+        # compute the probabilities for going left and right
        p_left = (<float64_t>n_left_samples / n_node_samples)

        # randomly sample a direction
6 changes: 5 additions & 1 deletion sktree/tree/_marginalize.py
@@ -105,7 +105,11 @@ def _apply_marginal_forest(
    if est.max_bins is not None:
        X = est._bin_data(X, is_training_data=False).astype(DTYPE)

-    results = Parallel(n_jobs=est.n_jobs, verbose=est.verbose, prefer="threads",)(
+    results = Parallel(
+        n_jobs=est.n_jobs,
+        verbose=est.verbose,
+        prefer="threads",
+    )(
        delayed(_apply_marginal_tree)(
            tree,
            X,
