Skip to content

Commit

Permalink
Update submodule
Browse files Browse the repository at this point in the history
Signed-off-by: Adam Li <adam2392@gmail.com>
  • Loading branch information
adam2392 committed Sep 9, 2024
1 parent cd793cd commit 406abfd
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 15 deletions.
2 changes: 1 addition & 1 deletion treeple/_lib/sklearn_fork
8 changes: 6 additions & 2 deletions treeple/ensemble/_honest_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -720,8 +720,12 @@ def oob_samples_(self):
oob_samples.append(_oob_samples)
return oob_samples

def _more_tags(self):
return {"multioutput": False}
def __sklearn_tags__(self):
# XXX: nans should be supportable in HRF
tags = super().__sklearn_tags__()
tags.classifier_tags.multi_output = False
tags.input_tags.allow_nan = False
return tags

def decision_path(self, X):
"""
Expand Down
10 changes: 8 additions & 2 deletions treeple/ensemble/_unsupervised_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,12 @@
)
from sklearn.metrics import calinski_harabasz_score
from sklearn.utils.parallel import Parallel, delayed
from sklearn.utils.validation import _check_sample_weight, check_is_fitted, check_random_state
from sklearn.utils.validation import (
_check_sample_weight,
check_is_fitted,
check_random_state,
validate_data,
)

from .._lib.sklearn.ensemble._forest import BaseForest
from .._lib.sklearn.tree._tree import DTYPE
Expand Down Expand Up @@ -88,7 +93,8 @@ def fit(self, X, y=None, sample_weight=None):
self._validate_params()

# Validate or convert input data
X = self._validate_data(
X = validate_data(
self,
X,
dtype=DTYPE, # accept_sparse="csc",
)
Expand Down
58 changes: 52 additions & 6 deletions treeple/tree/_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,13 @@ def _assign_labels(self, affinity_matrix):
predict_labels = cluster.fit_predict(affinity_matrix)
return predict_labels

def __sklearn_tags__(self):
# XXX: nans should be supportable in SPORF by just using RF-like splits on missing values
# However, for MORF it is not supported
tags = super().__sklearn_tags__()
tags.input_tags.allow_nan = False
return tags


class UnsupervisedObliqueDecisionTree(UnsupervisedDecisionTree):
"""Unsupervised oblique decision tree.
Expand Down Expand Up @@ -577,6 +584,13 @@ def _build_tree(
builder.build(self.tree_, X, sample_weight)
return self

def __sklearn_tags__(self):
# XXX: nans should be supportable in SPORF by just using RF-like splits on missing values
# However, for MORF it is not supported
tags = super().__sklearn_tags__()
tags.input_tags.allow_nan = False
return tags


class ObliqueDecisionTreeClassifier(SimMatrixMixin, DecisionTreeClassifier):
"""An oblique decision tree classifier.
Expand Down Expand Up @@ -1070,6 +1084,13 @@ def _update_tree(self, X, y, sample_weight):
self._prune_tree()
return self

def __sklearn_tags__(self):
# XXX: nans should be supportable in SPORF by just using RF-like splits on missing values
# However, for MORF it is not supported
tags = super().__sklearn_tags__()
tags.input_tags.allow_nan = False
return tags


class ObliqueDecisionTreeRegressor(SimMatrixMixin, DecisionTreeRegressor):
"""An oblique decision tree Regressor.
Expand Down Expand Up @@ -1450,6 +1471,13 @@ def _build_tree(
builder.build(self.tree_, X, y, sample_weight, None)
return self

def __sklearn_tags__(self):
# XXX: nans should be supportable in SPORF by just using RF-like splits on missing values
# However, for MORF it is not supported
tags = super().__sklearn_tags__()
tags.input_tags.allow_nan = False
return tags


class PatchObliqueDecisionTreeClassifier(SimMatrixMixin, DecisionTreeClassifier):
"""A oblique decision tree classifier that operates over patches of data.
Expand Down Expand Up @@ -1927,11 +1955,13 @@ def _build_tree(

return self

def _more_tags(self):
def __sklearn_tags__(self):
# XXX: nans should be supportable in SPORF by just using RF-like splits on missing values
# However, for MORF it is not supported
allow_nan = False
return {"multilabel": True, "allow_nan": allow_nan}
tags = super().__sklearn_tags__()
tags.classifier_tags.multi_label = True
tags.input_tags.allow_nan = False
return tags

@property
def _inheritable_fitted_attribute(self):
Expand Down Expand Up @@ -2407,11 +2437,13 @@ def _build_tree(

return self

def _more_tags(self):
def __sklearn_tags__(self):
# XXX: nans should be supportable in SPORF by just using RF-like splits on missing values
# However, for MORF it is not supported
allow_nan = False
return {"multilabel": True, "allow_nan": allow_nan}
tags = super().__sklearn_tags__()
tags.regressor_tags.multi_label = True
tags.input_tags.allow_nan = False
return tags


class ExtraObliqueDecisionTreeClassifier(SimMatrixMixin, DecisionTreeClassifier):
Expand Down Expand Up @@ -2846,6 +2878,13 @@ def _inheritable_fitted_attribute(self):
"feature_combinations_",
]

def __sklearn_tags__(self):
# XXX: nans should be supportable in SPORF by just using RF-like splits on missing values
# However, for MORF it is not supported
tags = super().__sklearn_tags__()
tags.input_tags.allow_nan = False
return tags


class ExtraObliqueDecisionTreeRegressor(SimMatrixMixin, DecisionTreeRegressor):
"""An oblique decision tree Regressor.
Expand Down Expand Up @@ -3237,3 +3276,10 @@ def _build_tree(
builder.build(self.tree_, X, y, sample_weight)

return self

def __sklearn_tags__(self):
# XXX: nans should be supportable in SPORF by just using RF-like splits on missing values
# However, for MORF it is not supported
tags = super().__sklearn_tags__()
tags.input_tags.allow_nan = False
return tags
4 changes: 0 additions & 4 deletions treeple/tree/_neighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,3 @@ def compute_similarity_matrix(self, X):
The similarity matrix among the samples.
"""
return compute_forest_similarity_matrix(self, X)

def _more_tags(self):
# XXX: no treeple estimators support NaNs as of now
return {"allow_nan": False}

0 comments on commit 406abfd

Please sign in to comment.