Skip to content

Commit

Permalink
Update submodule
Browse files Browse the repository at this point in the history
Signed-off-by: Adam Li <adam2392@gmail.com>
  • Loading branch information
adam2392 committed Sep 9, 2024
1 parent be2b655 commit cd793cd
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 13 deletions.
2 changes: 1 addition & 1 deletion treeple/_lib/sklearn_fork
12 changes: 6 additions & 6 deletions treeple/tree/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
# from .._lib.sklearn.tree import (
# DecisionTreeClassifier,
# DecisionTreeRegressor,
# ExtraTreeClassifier,
# ExtraTreeRegressor,
# )
from .._lib.sklearn.tree import (
DecisionTreeClassifier,
DecisionTreeRegressor,
ExtraTreeClassifier,
ExtraTreeRegressor,
)
from ._classes import (
ExtraObliqueDecisionTreeClassifier,
ExtraObliqueDecisionTreeRegressor,
Expand Down
12 changes: 6 additions & 6 deletions treeple/tree/_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from sklearn.cluster import AgglomerativeClustering
from sklearn.utils import check_random_state
from sklearn.utils._param_validation import Interval
from sklearn.utils.validation import check_is_fitted
from sklearn.utils.validation import check_is_fitted, validate_data

from .._lib.sklearn.tree import (
BaseDecisionTree,
Expand Down Expand Up @@ -216,7 +216,7 @@ def fit(self, X, y=None, sample_weight=None, check_input=True):
if check_input:
# TODO: allow X to be sparse
check_X_params = dict(dtype=DTYPE) # , accept_sparse="csc"
X = self._validate_data(X, validate_separately=(check_X_params))
X = validate_data(self, X, validate_separately=(check_X_params))
if issparse(X):
X.sort_indices()

Expand Down Expand Up @@ -1798,8 +1798,8 @@ def _build_tree(
self.feature_combinations_ = 1

if self.feature_weight is not None:
self.feature_weight = self._validate_data(
self.feature_weight, ensure_2d=True, dtype=DTYPE
self.feature_weight = validate_data(
self, self.feature_weight, ensure_2d=True, dtype=DTYPE
)
if self.feature_weight.shape != X.shape:
raise ValueError(
Expand Down Expand Up @@ -2277,8 +2277,8 @@ def _build_tree(
self.feature_combinations_ = 1

if self.feature_weight is not None:
self.feature_weight = self._validate_data(
self.feature_weight, ensure_2d=True, dtype=DTYPE
self.feature_weight = validate_data(
self, self.feature_weight, ensure_2d=True, dtype=DTYPE
)
if self.feature_weight.shape != X.shape:
raise ValueError(
Expand Down

0 comments on commit cd793cd

Please sign in to comment.