diff --git a/treeple/neighbors.py b/treeple/neighbors.py index c95c2968..b16e732f 100644 --- a/treeple/neighbors.py +++ b/treeple/neighbors.py @@ -7,6 +7,7 @@ from sklearn.neighbors import NearestNeighbors from sklearn.utils.validation import check_is_fitted, validate_data +from treeple.tree import DecisionTreeClassifier from treeple.tree._neighbors import _compute_distance_matrix, compute_forest_similarity_matrix @@ -31,13 +32,19 @@ class NearestNeighborsMetaEstimator(BaseEstimator, MetaEstimatorMixin): The number of parallel jobs to run for neighbors, by default None. """ - def __init__(self, estimator, n_neighbors=5, radius=1.0, algorithm="auto", n_jobs=None): + def __init__(self, estimator=None, n_neighbors=5, radius=1.0, algorithm="auto", n_jobs=None): self.estimator = estimator self.n_neighbors = n_neighbors self.algorithm = algorithm self.radius = radius self.n_jobs = n_jobs + def get_estimator(self): + if self.estimator is not None: + return DecisionTreeClassifier(random_state=0) + else: + return copy(self.estimator) + def fit(self, X, y=None): """Fit the nearest neighbors estimator from the training dataset. @@ -58,7 +65,7 @@ def fit(self, X, y=None): """ X, y = validate_data(self, X, y, accept_sparse="csc") - self.estimator_ = copy(self.estimator) + self.estimator_ = self.get_estimator() try: check_is_fitted(self.estimator_) except NotFittedError: