From 6eaf49554f004c0f8bc4aa2e6dfbb1f14cb5807a Mon Sep 17 00:00:00 2001 From: Adam Li Date: Thu, 10 Oct 2024 17:04:11 -0400 Subject: [PATCH] FEA Add warning to control against runtime might when wanting to run comight (#323) * Add warning to control against runtime might when wanting to run comight * Updating submodule --------- Signed-off-by: Adam Li --- doc/sphinxext/allow_nan_estimators.py | 2 +- treeple/_lib/sklearn_fork | 2 +- treeple/ensemble/_honest_forest.py | 19 ++++++++++++------ treeple/stats/forest.py | 18 +++++++++++++++-- treeple/stats/tests/test_forest.py | 27 ++++++++++++++++++++++++++ treeple/tests/test_honest_forest.py | 1 + treeple/tree/_honest_tree.py | 1 + treeple/tree/tests/test_honest_tree.py | 6 +++++- 8 files changed, 65 insertions(+), 11 deletions(-) diff --git a/doc/sphinxext/allow_nan_estimators.py b/doc/sphinxext/allow_nan_estimators.py index 45cbf9627..9fc4e2d0f 100755 --- a/doc/sphinxext/allow_nan_estimators.py +++ b/doc/sphinxext/allow_nan_estimators.py @@ -3,8 +3,8 @@ from docutils import nodes from docutils.parsers.rst import Directive from sklearn.utils import all_estimators +from sklearn.utils._test_common.instance_generator import _construct_instances from sklearn.utils._testing import SkipTest -from sklearn.utils.estimator_checks import _construct_instance class AllowNanEstimators(Directive): diff --git a/treeple/_lib/sklearn_fork b/treeple/_lib/sklearn_fork index e4b9728cb..4fd15fdf8 160000 --- a/treeple/_lib/sklearn_fork +++ b/treeple/_lib/sklearn_fork @@ -1 +1 @@ -Subproject commit e4b9728cb8667d0a40ed0c6c45f0414811f5f1f8 +Subproject commit 4fd15fdf88737e7e84e96217b2c9b0ce0c162c2c diff --git a/treeple/ensemble/_honest_forest.py b/treeple/ensemble/_honest_forest.py index 447371b37..0650d9a31 100644 --- a/treeple/ensemble/_honest_forest.py +++ b/treeple/ensemble/_honest_forest.py @@ -10,7 +10,7 @@ from sklearn.base import _fit_context, clone from sklearn.ensemble._base import _partition_estimators, _set_random_states from sklearn.utils import compute_sample_weight, resample -from sklearn.utils._param_validation import Interval, RealNotInt, StrOptions +from sklearn.utils._param_validation import HasMethods, Interval, RealNotInt, StrOptions from sklearn.utils.validation import check_is_fitted from .._lib.sklearn.ensemble._forest import ForestClassifier @@ -417,11 +417,18 @@ class labels (multi-output problem). Interval(RealNotInt, 0.0, None, closed="right"), Interval(Integral, 1, None, closed="left"), ] - _parameter_constraints["honest_fraction"] = [Interval(RealNotInt, 0.0, 1.0, closed="both")] - _parameter_constraints["honest_prior"] = [ - StrOptions({"empirical", "uniform", "ignore"}), - ] - _parameter_constraints["stratify"] = ["boolean"] + _parameter_constraints.update( + { + "tree_estimator": [ + HasMethods(["fit", "predict", "predict_proba", "apply"]), + None, + ], + "honest_fraction": [Interval(RealNotInt, 0.0, 1.0, closed="neither")], + "honest_prior": [StrOptions({"empirical", "uniform", "ignore"})], + "stratify": ["boolean"], + "tree_estimator_params": ["dict"], + } + ) def __init__( self, diff --git a/treeple/stats/forest.py b/treeple/stats/forest.py index 5b5b0b056..34ea690b5 100644 --- a/treeple/stats/forest.py +++ b/treeple/stats/forest.py @@ -1,6 +1,7 @@ import threading from collections import namedtuple from typing import Callable +from warnings import warn import numpy as np from joblib import Parallel, delayed @@ -10,6 +11,8 @@ from sklearn.utils.multiclass import type_of_target from .._lib.sklearn.ensemble._forest import ForestClassifier +from ..ensemble import HonestForestClassifier +from ..tree import MultiViewDecisionTreeClassifier from ..tree._classes import DTYPE from .permuteforest import PermutationHonestForestClassifier from .utils import METRIC_FUNCTIONS, POSITIVE_METRICS, _compute_null_distribution_coleman @@ -38,8 +41,8 @@ def _parallel_predict_proba_oob(predict_proba, X, out, idx, test_idx, lock): def build_coleman_forest( - est, - perm_est, + est: HonestForestClassifier, + perm_est: PermutationHonestForestClassifier, X, y, covariate_index=None, @@ -111,6 +114,9 @@ def build_coleman_forest( """ metric_func: Callable[[ArrayLike, ArrayLike], float] = METRIC_FUNCTIONS[metric] + if not isinstance(est, HonestForestClassifier): + raise RuntimeError(f"Original forest must be a HonestForestClassifier, got {type(est)}") + # build two sets of forests est, orig_forest_proba = build_oob_forest(est, X, y, verbose=verbose) @@ -118,6 +124,14 @@ def build_coleman_forest( raise RuntimeError( f"Permutation forest must be a PermutationHonestForestClassifier, got {type(perm_est)}" ) + + if covariate_index is None and isinstance(est.tree_estimator, MultiViewDecisionTreeClassifier): + warn( + "Covariate index is not defined, but a MultiViewDecisionTreeClassifier is used. " + "If using CoMIGHT, one should define the covariate index to permute. " + "Defaulting to use MIGHT." + ) + perm_est, perm_forest_proba = build_oob_forest( perm_est, X, y, verbose=verbose, covariate_index=covariate_index ) diff --git a/treeple/stats/tests/test_forest.py b/treeple/stats/tests/test_forest.py index 922b057e5..ae59abec7 100644 --- a/treeple/stats/tests/test_forest.py +++ b/treeple/stats/tests/test_forest.py @@ -430,3 +430,30 @@ def test_build_oob_random_forest(): assert len(np.unique(structure_samples[tree_idx])) + len(oob_samples_list[tree_idx]) == len( samples ), f"{tree_idx} {len(structure_samples[tree_idx])} + {len(oob_samples_list[tree_idx])} != {len(samples)}" + + +def test_build_coleman_warning_with_multiview_without_covariate_index(): + """Test warning is raised in build_coleman_forest with multiview without covariate_index.""" + + est = HonestForestClassifier( + n_estimators=100, + random_state=0, + bootstrap=True, + max_samples=1.0, + honest_fraction=0.5, + stratify=True, + tree_estimator=MultiViewDecisionTreeClassifier(), + ) + perm_est = PermutationHonestForestClassifier( + n_estimators=100, + random_state=0, + bootstrap=True, + max_samples=1.0, + honest_fraction=0.5, + stratify=True, + tree_estimator=MultiViewDecisionTreeClassifier(), + ) + X = rng.normal(0, 1, (100, 2)) + y = np.array([0, 1] * 50) + with pytest.warns(UserWarning, match="Covariate index is not defined"): + build_coleman_forest(est, perm_est, X, y, metric="s@98", n_repeats=1000, seed=0) diff --git a/treeple/tests/test_honest_forest.py b/treeple/tests/test_honest_forest.py index f1e392109..7a6ee4568 100644 --- a/treeple/tests/test_honest_forest.py +++ b/treeple/tests/test_honest_forest.py @@ -313,6 +313,7 @@ def test_sklearn_compatible_estimator(estimator, check): # TODO: this is an error. Somehow a segfault is raised when fit is called first and # then partial_fit "check_fit_score_takes_y", + "check_do_not_raise_errors_in_init_or_set_params", ]: pytest.skip() check(estimator) diff --git a/treeple/tree/_honest_tree.py b/treeple/tree/_honest_tree.py index 7a61242d1..26f5b7ca5 100644 --- a/treeple/tree/_honest_tree.py +++ b/treeple/tree/_honest_tree.py @@ -284,6 +284,7 @@ class frequency in the voting subsample. "honest_fraction": [Interval(RealNotInt, 0.0, 1.0, closed="neither")], "honest_prior": [StrOptions({"empirical", "uniform", "ignore"})], "stratify": ["boolean"], + "tree_estimator_params": ["dict"], } def __init__( diff --git a/treeple/tree/tests/test_honest_tree.py b/treeple/tree/tests/test_honest_tree.py index 6f57f12d1..bdc714e55 100644 --- a/treeple/tree/tests/test_honest_tree.py +++ b/treeple/tree/tests/test_honest_tree.py @@ -171,7 +171,11 @@ def test_sklearn_compatible_estimator(estimator, check): # XXX: can include this "generalization" in the future if it's useful # zero sample weight is not "really supported" in honest subsample trees since sample weight # for fitting the tree's splits - if check.func.__name__ in ["check_class_weight_classifiers", "check_classifier_multioutput"]: + if check.func.__name__ in [ + "check_class_weight_classifiers", + "check_classifier_multioutput", + "check_do_not_raise_errors_in_init_or_set_params", + ]: pytest.skip() check(estimator)