Skip to content

Commit

Permalink
MNT: Make error message clearer for n_neighbors (scikit-learn#23317)
Browse files Browse the repository at this point in the history
Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com>
Co-authored-by: Tim Head <betatim@gmail.com>
Co-authored-by: Tom Dupré la Tour <tom.duprelatour.10@gmail.com>
Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com>
  • Loading branch information
5 people authored Nov 18, 2023
1 parent 9e8f10e commit 5c4288d
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 2 deletions.
4 changes: 4 additions & 0 deletions doc/whats_new/v1.4.rst
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,10 @@ Changelog
when `radius` is large and `algorithm="brute"` with non-Euclidean metrics.
:pr:`26828` by :user:`Omar Salman <OmarManzoor>`.

- |Fix| Improve error message for :class:`neighbors.LocalOutlierFactor`
when it is invoked with `n_samples = n_neighbors`.
:pr:`23317` by :user:`Bharat Raghunathan <Bharat123rox>`.

:mod:`sklearn.preprocessing`
............................

Expand Down
10 changes: 8 additions & 2 deletions sklearn/neighbors/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -813,9 +813,15 @@ class from an array representing our data set and ask who's

n_samples_fit = self.n_samples_fit_
if n_neighbors > n_samples_fit:
if query_is_train:
n_neighbors -= 1 # ok to modify inplace because an error is raised
inequality_str = "n_neighbors < n_samples_fit"
else:
inequality_str = "n_neighbors <= n_samples_fit"
raise ValueError(
"Expected n_neighbors <= n_samples, "
" but n_samples = %d, n_neighbors = %d" % (n_samples_fit, n_neighbors)
f"Expected {inequality_str}, but "
f"n_neighbors = {n_neighbors}, n_samples_fit = {n_samples_fit}, "
f"n_samples = {X.shape[0]}" # include n_samples for common tests
)

n_jobs = effective_n_jobs(self.n_jobs)
Expand Down
44 changes: 44 additions & 0 deletions sklearn/neighbors/tests/test_lof.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,50 @@ def test_sparse(csr_container):
lof.fit_predict(X)


def test_lof_error_n_neighbors_too_large():
"""Check that we raise a proper error message when n_neighbors == n_samples.
Non-regression test for:
https://github.com/scikit-learn/scikit-learn/issues/17207
"""
X = np.ones((7, 7))

msg = (
"Expected n_neighbors < n_samples_fit, but n_neighbors = 1, "
"n_samples_fit = 1, n_samples = 1"
)
with pytest.raises(ValueError, match=msg):
lof = neighbors.LocalOutlierFactor(n_neighbors=1).fit(X[:1])

lof = neighbors.LocalOutlierFactor(n_neighbors=2).fit(X[:2])
assert lof.n_samples_fit_ == 2

msg = (
"Expected n_neighbors < n_samples_fit, but n_neighbors = 2, "
"n_samples_fit = 2, n_samples = 2"
)
with pytest.raises(ValueError, match=msg):
lof.kneighbors(None, n_neighbors=2)

distances, indices = lof.kneighbors(None, n_neighbors=1)
assert distances.shape == (2, 1)
assert indices.shape == (2, 1)

msg = (
"Expected n_neighbors <= n_samples_fit, but n_neighbors = 3, "
"n_samples_fit = 2, n_samples = 7"
)
with pytest.raises(ValueError, match=msg):
lof.kneighbors(X, n_neighbors=3)

(
distances,
indices,
) = lof.kneighbors(X, n_neighbors=2)
assert distances.shape == (7, 2)
assert indices.shape == (7, 2)


@pytest.mark.parametrize("algorithm", ["auto", "ball_tree", "kd_tree", "brute"])
@pytest.mark.parametrize("novelty", [True, False])
@pytest.mark.parametrize("contamination", [0.5, "auto"])
Expand Down

0 comments on commit 5c4288d

Please sign in to comment.