Skip to content

Commit

Permalink
assert integer in unique_rows
Browse files Browse the repository at this point in the history
  • Loading branch information
nschloe committed Mar 23, 2021
1 parent 763af37 commit 55de725
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 2 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ to PR here.
* ```python
npx.unique_rows(a, return_inverse=False, return_counts=False)
```
Returns the unique rows of the array `a`. The numpy alternative `np.unique(a, axis=0)`
[is slow](https://github.com/numpy/numpy/issues/11136).
Returns the unique rows of the integer array `a`. The numpy alternative `np.unique(a,
axis=0)` [is slow](https://github.com/numpy/numpy/issues/11136).

### License
npx is published under the [MIT license](https://en.wikipedia.org/wiki/MIT_License).
3 changes: 3 additions & 0 deletions npx/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,9 @@ def unique_rows(a, return_inverse=False, return_counts=False):
# The numpy alternative `np.unique(a, axis=0)` is slow; cf.
# <https://github.com/numpy/numpy/issues/11136>.
a = np.asarray(a)
if not np.issubdtype(a.dtype, np.integer):
raise ValueError(f"Input array must be integer type, got {a.dtype}.")

b = np.ascontiguousarray(a).view(np.dtype((np.void, a.dtype.itemsize * a.shape[1])))
out = np.unique(b, return_inverse=return_inverse, return_counts=return_counts)
# out[0] are the sorted, unique rows
Expand Down
5 changes: 5 additions & 0 deletions test/test_npx.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
import pytest

import npx

Expand Down Expand Up @@ -60,3 +61,7 @@ def test_unique_rows():

a_unique = npx.unique_rows(a)
assert np.all(a_unique == [[1, 2], [1, 4]])

a = [1.1, 2.2]
with pytest.raises(ValueError):
a_unique = npx.unique_rows(a)

0 comments on commit 55de725

Please sign in to comment.