diff --git a/sklearn/ensemble/_bagging.py b/sklearn/ensemble/_bagging.py index 423fc0ec6449a..dd39b8cb607a8 100644 --- a/sklearn/ensemble/_bagging.py +++ b/sklearn/ensemble/_bagging.py @@ -41,6 +41,7 @@ _check_method_params, _check_sample_weight, _deprecate_positional_args, + _estimator_has, check_is_fitted, has_fit_parameter, validate_data, @@ -269,22 +270,6 @@ def _parallel_predict_regression(estimators, estimators_features, X): ) -def _estimator_has(attr): - """Check if we can delegate a method to the underlying estimator. - - First, we check the first fitted estimator if available, otherwise we - check the estimator attribute. - """ - - def check(self): - if hasattr(self, "estimators_"): - return hasattr(self.estimators_[0], attr) - else: # self.estimator is not None - return hasattr(self.estimator, attr) - - return check - - class BaseBagging(BaseEnsemble, metaclass=ABCMeta): """Base class for Bagging meta-estimator. @@ -1033,7 +1018,9 @@ def predict_log_proba(self, X): return log_proba - @available_if(_estimator_has("decision_function")) + @available_if( + _estimator_has("decision_function", delegates=("estimators_", "estimator")) + ) def decision_function(self, X): """Average of the decision functions of the base classifiers. diff --git a/sklearn/ensemble/_stacking.py b/sklearn/ensemble/_stacking.py index 57bc63a1862e9..bf5ff39c13165 100644 --- a/sklearn/ensemble/_stacking.py +++ b/sklearn/ensemble/_stacking.py @@ -40,31 +40,13 @@ _check_feature_names_in, _check_response_method, _deprecate_positional_args, + _estimator_has, check_is_fitted, column_or_1d, ) from ._base import _BaseHeterogeneousEnsemble, _fit_single_estimator -def _estimator_has(attr): - """Check if we can delegate a method to the underlying estimator. - - First, we check the fitted `final_estimator_` if available, otherwise we check the - unfitted `final_estimator`. We raise the original `AttributeError` if `attr` does - not exist. This function is used together with `available_if`. - """ - - def check(self): - if hasattr(self, "final_estimator_"): - getattr(self.final_estimator_, attr) - else: - getattr(self.final_estimator, attr) - - return True - - return check - - class _BaseStacking(TransformerMixin, _BaseHeterogeneousEnsemble, metaclass=ABCMeta): """Base class for stacking method.""" @@ -364,7 +346,9 @@ def get_feature_names_out(self, input_features=None): return np.asarray(meta_names, dtype=object) - @available_if(_estimator_has("predict")) + @available_if( + _estimator_has("predict", delegates=("final_estimator_", "final_estimator")) + ) def predict(self, X, **predict_params): """Predict target for X. @@ -732,7 +716,9 @@ def fit(self, X, y, *, sample_weight=None, **fit_params): fit_params["sample_weight"] = sample_weight return super().fit(X, y_encoded, **fit_params) - @available_if(_estimator_has("predict")) + @available_if( + _estimator_has("predict", delegates=("final_estimator_", "final_estimator")) + ) def predict(self, X, **predict_params): """Predict target for X. @@ -785,7 +771,11 @@ def predict(self, X, **predict_params): y_pred = self._label_encoder.inverse_transform(y_pred) return y_pred - @available_if(_estimator_has("predict_proba")) + @available_if( + _estimator_has( + "predict_proba", delegates=("final_estimator_", "final_estimator") + ) + ) def predict_proba(self, X): """Predict class probabilities for `X` using the final estimator. @@ -809,7 +799,11 @@ def predict_proba(self, X): y_pred = np.array([preds[:, 0] for preds in y_pred]).T return y_pred - @available_if(_estimator_has("decision_function")) + @available_if( + _estimator_has( + "decision_function", delegates=("final_estimator_", "final_estimator") + ) + ) def decision_function(self, X): """Decision function for samples in `X` using the final estimator. @@ -1125,7 +1119,9 @@ def fit_transform(self, X, y, *, sample_weight=None, **fit_params): fit_params["sample_weight"] = sample_weight return super().fit_transform(X, y, **fit_params) - @available_if(_estimator_has("predict")) + @available_if( + _estimator_has("predict", delegates=("final_estimator_", "final_estimator")) + ) def predict(self, X, **predict_params): """Predict target for X. diff --git a/sklearn/feature_selection/_from_model.py b/sklearn/feature_selection/_from_model.py index d5476e3f06abf..28af66d524623 100644 --- a/sklearn/feature_selection/_from_model.py +++ b/sklearn/feature_selection/_from_model.py @@ -19,6 +19,7 @@ from ..utils.metaestimators import available_if from ..utils.validation import ( _check_feature_names, + _estimator_has, _num_features, check_is_fitted, check_scalar, @@ -76,25 +77,6 @@ def _calculate_threshold(estimator, importances, threshold): return threshold -def _estimator_has(attr): - """Check if we can delegate a method to the underlying estimator. - - First, we check the fitted `estimator_` if available, otherwise we check the - unfitted `estimator`. We raise the original `AttributeError` if `attr` does - not exist. This function is used together with `available_if`. - """ - - def check(self): - if hasattr(self, "estimator_"): - getattr(self.estimator_, attr) - else: - getattr(self.estimator, attr) - - return True - - return check - - class SelectFromModel(MetaEstimatorMixin, SelectorMixin, BaseEstimator): """Meta-transformer for selecting features based on importance weights. diff --git a/sklearn/feature_selection/_rfe.py b/sklearn/feature_selection/_rfe.py index bbd7a80ead458..bd6a28b97b557 100644 --- a/sklearn/feature_selection/_rfe.py +++ b/sklearn/feature_selection/_rfe.py @@ -29,6 +29,7 @@ from ..utils.validation import ( _check_method_params, _deprecate_positional_args, + _estimator_has, check_is_fitted, validate_data, ) @@ -64,25 +65,6 @@ def _rfe_single_fit(rfe, estimator, X, y, train, test, scorer, routed_params): return rfe.step_scores_, rfe.step_n_features_ -def _estimator_has(attr): - """Check if we can delegate a method to the underlying estimator. - - First, we check the fitted `estimator_` if available, otherwise we check the - unfitted `estimator`. We raise the original `AttributeError` if `attr` does - not exist. This function is used together with `available_if`. - """ - - def check(self): - if hasattr(self, "estimator_"): - getattr(self.estimator_, attr) - else: - getattr(self.estimator, attr) - - return True - - return check - - class RFE(SelectorMixin, MetaEstimatorMixin, BaseEstimator): """Feature ranking with recursive feature elimination. diff --git a/sklearn/model_selection/_classification_threshold.py b/sklearn/model_selection/_classification_threshold.py index 86c982385f5ee..8ac7a67a03433 100644 --- a/sklearn/model_selection/_classification_threshold.py +++ b/sklearn/model_selection/_classification_threshold.py @@ -36,6 +36,7 @@ from ..utils.parallel import Parallel, delayed from ..utils.validation import ( _check_method_params, + _estimator_has, _num_samples, check_is_fitted, indexable, @@ -50,23 +51,6 @@ def _check_is_fitted(estimator): check_is_fitted(estimator, "estimator_") -def _estimator_has(attr): - """Check if we can delegate a method to the underlying estimator. - - First, we check the fitted estimator if available, otherwise we - check the unfitted estimator. - """ - - def check(self): - if hasattr(self, "estimator_"): - getattr(self.estimator_, attr) - else: - getattr(self.estimator, attr) - return True - - return check - - class BaseThresholdClassifier(ClassifierMixin, MetaEstimatorMixin, BaseEstimator): """Base class for binary classifiers that set a non-default decision threshold. diff --git a/sklearn/model_selection/_search.py b/sklearn/model_selection/_search.py index 5a8284c49888b..a8431b74259b4 100644 --- a/sklearn/model_selection/_search.py +++ b/sklearn/model_selection/_search.py @@ -356,7 +356,7 @@ def _check_refit(search_cv, attr): ) -def _estimator_has(attr): +def _search_estimator_has(attr): """Check if we can delegate a method to the underlying estimator. Calling a prediction method will only be available if `refit=True`. In @@ -555,7 +555,7 @@ def score(self, X, y=None, **params): score = score[self.refit] return score - @available_if(_estimator_has("score_samples")) + @available_if(_search_estimator_has("score_samples")) def score_samples(self, X): """Call score_samples on the estimator with the best found parameters. @@ -578,7 +578,7 @@ def score_samples(self, X): check_is_fitted(self) return self.best_estimator_.score_samples(X) - @available_if(_estimator_has("predict")) + @available_if(_search_estimator_has("predict")) def predict(self, X): """Call predict on the estimator with the best found parameters. @@ -600,7 +600,7 @@ def predict(self, X): check_is_fitted(self) return self.best_estimator_.predict(X) - @available_if(_estimator_has("predict_proba")) + @available_if(_search_estimator_has("predict_proba")) def predict_proba(self, X): """Call predict_proba on the estimator with the best found parameters. @@ -623,7 +623,7 @@ def predict_proba(self, X): check_is_fitted(self) return self.best_estimator_.predict_proba(X) - @available_if(_estimator_has("predict_log_proba")) + @available_if(_search_estimator_has("predict_log_proba")) def predict_log_proba(self, X): """Call predict_log_proba on the estimator with the best found parameters. @@ -646,7 +646,7 @@ def predict_log_proba(self, X): check_is_fitted(self) return self.best_estimator_.predict_log_proba(X) - @available_if(_estimator_has("decision_function")) + @available_if(_search_estimator_has("decision_function")) def decision_function(self, X): """Call decision_function on the estimator with the best found parameters. @@ -669,7 +669,7 @@ def decision_function(self, X): check_is_fitted(self) return self.best_estimator_.decision_function(X) - @available_if(_estimator_has("transform")) + @available_if(_search_estimator_has("transform")) def transform(self, X): """Call transform on the estimator with the best found parameters. @@ -691,7 +691,7 @@ def transform(self, X): check_is_fitted(self) return self.best_estimator_.transform(X) - @available_if(_estimator_has("inverse_transform")) + @available_if(_search_estimator_has("inverse_transform")) def inverse_transform(self, X=None, Xt=None): """Call inverse_transform on the estimator with the best found params. @@ -746,7 +746,7 @@ def classes_(self): Only available when `refit=True` and the estimator is a classifier. """ - _estimator_has("classes_")(self) + _search_estimator_has("classes_")(self) return self.best_estimator_.classes_ def _run_search(self, evaluate_candidates): diff --git a/sklearn/semi_supervised/_self_training.py b/sklearn/semi_supervised/_self_training.py index 5ac0b8ca28533..d56ebf887828c 100644 --- a/sklearn/semi_supervised/_self_training.py +++ b/sklearn/semi_supervised/_self_training.py @@ -17,7 +17,7 @@ process_routing, ) from ..utils.metaestimators import available_if -from ..utils.validation import check_is_fitted, validate_data +from ..utils.validation import _estimator_has, check_is_fitted, validate_data __all__ = ["SelfTrainingClassifier"] @@ -25,25 +25,6 @@ # SPDX-License-Identifier: BSD-3-Clause -def _estimator_has(attr): - """Check if we can delegate a method to the underlying estimator. - - First, we check the fitted `estimator_` if available, otherwise we check - the unfitted `estimator`. We raise the original `AttributeError` if - `attr` does not exist. This function is used together with `available_if`. - """ - - def check(self): - if hasattr(self, "estimator_"): - getattr(self.estimator_, attr) - else: - getattr(self.estimator, attr) - - return True - - return check - - class SelfTrainingClassifier(ClassifierMixin, MetaEstimatorMixin, BaseEstimator): """Self-training classifier. diff --git a/sklearn/utils/tests/test_validation.py b/sklearn/utils/tests/test_validation.py index 20aee5b439252..5ae5a003d0d0a 100644 --- a/sklearn/utils/tests/test_validation.py +++ b/sklearn/utils/tests/test_validation.py @@ -67,6 +67,7 @@ _check_sample_weight, _check_y, _deprecate_positional_args, + _estimator_has, _get_feature_names, _is_fitted, _is_pandas_df, @@ -1163,6 +1164,93 @@ def test_check_array_memmap(copy): assert X_checked.flags["WRITEABLE"] == copy +@pytest.mark.parametrize( + "estimator_name, estimator_value, delegates, expected_result, expected_exception", + [ + ( + "estimator_", + type("SubEstimator", (), {"attribute_present": True}), + None, # default delegates - ["estimator_", "estimator"] + True, # expected_result is True b/c delegate and attribute are present + None, # expected_exception not relevant for this case + ), + ( + "estimator", + type("SubEstimator", (), {"attribute_present": True}), + None, # default delegates - ["estimator_", "estimator"] + True, # expected_result is True b/c delegate and attribute are present + None, # expected_exception not relevant for this case + ), + ( + "estimators_", + [ + type("SubEstimator", (), {"attribute_present": True}) + ], # list of sub-estimators + ["estimators_"], + True, # expected_result is True b/c delegate and attribute are present + None, # expected_exception not relevant for this case + ), + ( + "custom_estimator", # custom estimator attribute name + type("SubEstimator", (), {"attribute_present": True}), + ["custom_estimator"], # custom delegates + True, # expected_result is True b/c delegate and attribute are present + None, # expected_exception not relevant for this case + ), + ( + "no_estimator", # no estimator attribute name + type("SubEstimator", (), {"attribute_present": True}), + None, # default delegates - ["estimator_", "estimator"] + None, # expected_result is not relevant for this case + ValueError, # should raise ValueError b/c no estimator found from delegates + ), + ( + "estimator", + type("SubEstimator", (), {"attribute_absent": True}), # attribute_absent + None, # default delegates - ["estimator_", "estimator"] + None, # expected_result is not relevant for this case + AttributeError, # should raise AttributeError b/c attribute is absent + ), + ], + ids=[ + "fitted_estimator_with_default_delegates", + "estimator_with_default_delegates", + "list_of_estimators_with_estimators_", + "custom_estimator_with_custom_delegates", + "no_estimator_with_default_delegates", + "estimator_with_default_delegates_but_absent_attribute", + ], +) +def test_estimator_has( + estimator_name, estimator_value, delegates, expected_result, expected_exception +): + """ + Tests the _estimator_has function by verifying: + - Functionality with default and custom delegates. + - Raises ValueError if delegates are missing. + - Raises AttributeError if the specified attribute is missing. + """ + + # always checks for attribute - "attribute_present" + # ["estimator_", "estimator"] is default value for delegates + if delegates is None: + check = _estimator_has("attribute_present") + else: + check = _estimator_has("attribute_present", delegates=delegates) + + class MockEstimator: + pass + + a = MockEstimator() + setattr(a, estimator_name, estimator_value) + + if expected_exception: + with pytest.raises(expected_exception): + check(a) + else: + assert check(a) == expected_result + + @pytest.mark.parametrize( "retype", [ diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index 383e262e0971e..649df1de8f223 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -7,6 +7,7 @@ import operator import sys import warnings +from collections.abc import Sequence from contextlib import suppress from functools import reduce, wraps from inspect import Parameter, isclass, signature @@ -1738,6 +1739,48 @@ def check_is_fitted(estimator, attributes=None, *, msg=None, all_or_any=all): raise NotFittedError(msg % {"name": type(estimator).__name__}) +def _estimator_has(attr, *, delegates=("estimator_", "estimator")): + """Check if we can delegate a method to the underlying estimator. + + We check the `delegates` in the order they are passed. By default, we first check + the fitted estimator if available, otherwise we check the unfitted estimator. + + Parameters + ---------- + attr : str + Name of the attribute the delegate might or might not have. + + delegates: tuple of str, default=("estimator_", "estimator") + A tuple of sub-estimator(s) to check if we can delegate the `attr` method. + + Returns + ------- + check : function + Function to check if the delegate has the attribute. + + Raises + ------ + ValueError + Raised when none of the delegates are present in the object. + """ + + def check(self): + for delegate in delegates: + # In meta estimators with multiple sub estimators, + # only the attribute of the first sub estimator is checked, + # assuming uniformity across all sub estimators. + if hasattr(self, delegate): + delegator = getattr(self, delegate) + if isinstance(delegator, Sequence): + return getattr(delegator[0], attr) + else: + return getattr(delegator, attr) + + raise ValueError(f"None of the delegates {delegates} are present in the class.") + + return check + + def check_non_negative(X, whom): """ Check if there is any negative value in an array.