Skip to content

Commit

Permalink
FIX proper inheritance for SGDOneClassSVM (scikit-learn#30227)
Browse files Browse the repository at this point in the history
  • Loading branch information
glemaitre authored Nov 5, 2024
1 parent 70aab36 commit 102663d
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 1 deletion.
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
- :class:`~sklearn.linear_model.SGDOneClassSVM` now correctly inherits from
:class:`~sklearn.base.OutlierMixin` and the tags are correctly set.
By :user:`Guillaume Lemaitre <glemaitre>`
2 changes: 1 addition & 1 deletion sklearn/linear_model/_stochastic_gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -2084,7 +2084,7 @@ def __sklearn_tags__(self):
return tags


class SGDOneClassSVM(BaseSGD, OutlierMixin):
class SGDOneClassSVM(OutlierMixin, BaseSGD):
"""Solves linear One-Class SVM using Stochastic Gradient Descent.
This implementation is meant to be used with a kernel approximation
Expand Down
10 changes: 10 additions & 0 deletions sklearn/linear_model/tests/test_sgd.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import LabelEncoder, MinMaxScaler, StandardScaler, scale
from sklearn.svm import OneClassSVM
from sklearn.utils import get_tags
from sklearn.utils._testing import (
assert_allclose,
assert_almost_equal,
Expand Down Expand Up @@ -2170,3 +2171,12 @@ def test_passive_aggressive_deprecated_average(Estimator):
est = Estimator(average=0)
with pytest.warns(FutureWarning, match="average=0"):
est.fit(X, Y)


def test_sgd_one_class_svm_estimator_type():
"""Check that SGDOneClassSVM has the correct estimator type.
Non-regression test for if the mixin was not on the left.
"""
sgd_ocsvm = SGDOneClassSVM()
assert get_tags(sgd_ocsvm).estimator_type == "outlier_detector"

0 comments on commit 102663d

Please sign in to comment.