diff --git a/doc/whats_new/v1.4.rst b/doc/whats_new/v1.4.rst
index ce16556feb88c..507b22cdf510a 100644
--- a/doc/whats_new/v1.4.rst
+++ b/doc/whats_new/v1.4.rst
@@ -474,6 +474,10 @@ Changelog
   when `radius` is large and `algorithm="brute"` with non-Euclidean metrics.
   :pr:`26828` by :user:`Omar Salman <OmarManzoor>`.
 
+- |Fix| Improve error message for :class:`neighbors.LocalOutlierFactor`
+  when it is invoked with `n_samples = n_neighbors`.
+  :pr:`23317` by :user:`Bharat Raghunathan <Bharat123rox>`.
+
 :mod:`sklearn.preprocessing`
 ............................
 
diff --git a/sklearn/neighbors/_base.py b/sklearn/neighbors/_base.py
index 519db9bead3d3..848c8b7c9dc5a 100644
--- a/sklearn/neighbors/_base.py
+++ b/sklearn/neighbors/_base.py
@@ -813,9 +813,15 @@ class from an array representing our data set and ask who's
 
         n_samples_fit = self.n_samples_fit_
         if n_neighbors > n_samples_fit:
+            if query_is_train:
+                n_neighbors -= 1  # ok to modify inplace because an error is raised
+                inequality_str = "n_neighbors < n_samples_fit"
+            else:
+                inequality_str = "n_neighbors <= n_samples_fit"
             raise ValueError(
-                "Expected n_neighbors <= n_samples, "
-                " but n_samples = %d, n_neighbors = %d" % (n_samples_fit, n_neighbors)
+                f"Expected {inequality_str}, but "
+                f"n_neighbors = {n_neighbors}, n_samples_fit = {n_samples_fit}, "
+                f"n_samples = {X.shape[0]}"  # include n_samples for common tests
             )
 
         n_jobs = effective_n_jobs(self.n_jobs)
diff --git a/sklearn/neighbors/tests/test_lof.py b/sklearn/neighbors/tests/test_lof.py
index 221d78243915f..7233beddafe9c 100644
--- a/sklearn/neighbors/tests/test_lof.py
+++ b/sklearn/neighbors/tests/test_lof.py
@@ -255,6 +255,50 @@ def test_sparse(csr_container):
         lof.fit_predict(X)
 
 
+def test_lof_error_n_neighbors_too_large():
+    """Check that we raise a proper error message when n_neighbors == n_samples.
+
+    Non-regression test for:
+    https://github.com/scikit-learn/scikit-learn/issues/17207
+    """
+    X = np.ones((7, 7))
+
+    msg = (
+        "Expected n_neighbors < n_samples_fit, but n_neighbors = 1, "
+        "n_samples_fit = 1, n_samples = 1"
+    )
+    with pytest.raises(ValueError, match=msg):
+        lof = neighbors.LocalOutlierFactor(n_neighbors=1).fit(X[:1])
+
+    lof = neighbors.LocalOutlierFactor(n_neighbors=2).fit(X[:2])
+    assert lof.n_samples_fit_ == 2
+
+    msg = (
+        "Expected n_neighbors < n_samples_fit, but n_neighbors = 2, "
+        "n_samples_fit = 2, n_samples = 2"
+    )
+    with pytest.raises(ValueError, match=msg):
+        lof.kneighbors(None, n_neighbors=2)
+
+    distances, indices = lof.kneighbors(None, n_neighbors=1)
+    assert distances.shape == (2, 1)
+    assert indices.shape == (2, 1)
+
+    msg = (
+        "Expected n_neighbors <= n_samples_fit, but n_neighbors = 3, "
+        "n_samples_fit = 2, n_samples = 7"
+    )
+    with pytest.raises(ValueError, match=msg):
+        lof.kneighbors(X, n_neighbors=3)
+
+    (
+        distances,
+        indices,
+    ) = lof.kneighbors(X, n_neighbors=2)
+    assert distances.shape == (7, 2)
+    assert indices.shape == (7, 2)
+
+
 @pytest.mark.parametrize("algorithm", ["auto", "ball_tree", "kd_tree", "brute"])
 @pytest.mark.parametrize("novelty", [True, False])
 @pytest.mark.parametrize("contamination", [0.5, "auto"])