diff --git a/README.md b/README.md index a3933a7..a5d324b 100644 --- a/README.md +++ b/README.md @@ -49,8 +49,8 @@ to PR here. * ```python npx.unique_rows(a, return_inverse=False, return_counts=False) ``` - Returns the unique rows of the array `a`. The numpy alternative `np.unique(a, axis=0)` - [is slow](https://github.com/numpy/numpy/issues/11136). + Returns the unique rows of the integer array `a`. The numpy alternative `np.unique(a, + axis=0)` [is slow](https://github.com/numpy/numpy/issues/11136). ### License npx is published under the [MIT license](https://en.wikipedia.org/wiki/MIT_License). diff --git a/npx/main.py b/npx/main.py index cf11838..0fbfe99 100644 --- a/npx/main.py +++ b/npx/main.py @@ -84,6 +84,9 @@ def unique_rows(a, return_inverse=False, return_counts=False): # The numpy alternative `np.unique(a, axis=0)` is slow; cf. # . a = np.asarray(a) + if not np.issubdtype(a.dtype, np.integer): + raise ValueError(f"Input array must be integer type, got {a.dtype}.") + b = np.ascontiguousarray(a).view(np.dtype((np.void, a.dtype.itemsize * a.shape[1]))) out = np.unique(b, return_inverse=return_inverse, return_counts=return_counts) # out[0] are the sorted, unique rows diff --git a/test/test_npx.py b/test/test_npx.py index d7a3115..5611e57 100644 --- a/test/test_npx.py +++ b/test/test_npx.py @@ -1,4 +1,5 @@ import numpy as np +import pytest import npx @@ -60,3 +61,7 @@ def test_unique_rows(): a_unique = npx.unique_rows(a) assert np.all(a_unique == [[1, 2], [1, 4]]) + + a = [1.1, 2.2] + with pytest.raises(ValueError): + a_unique = npx.unique_rows(a)