From 55de725e47a2dac517635dce605c460e9c952109 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nico=20Schl=C3=B6mer?= Date: Tue, 23 Mar 2021 18:47:13 +0100 Subject: [PATCH] assert integer in unique_rows --- README.md | 4 ++-- npx/main.py | 3 +++ test/test_npx.py | 5 +++++ 3 files changed, 10 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index a3933a7..a5d324b 100644 --- a/README.md +++ b/README.md @@ -49,8 +49,8 @@ to PR here. * ```python npx.unique_rows(a, return_inverse=False, return_counts=False) ``` - Returns the unique rows of the array `a`. The numpy alternative `np.unique(a, axis=0)` - [is slow](https://github.com/numpy/numpy/issues/11136). + Returns the unique rows of the integer array `a`. The numpy alternative `np.unique(a, + axis=0)` [is slow](https://github.com/numpy/numpy/issues/11136). ### License npx is published under the [MIT license](https://en.wikipedia.org/wiki/MIT_License). diff --git a/npx/main.py b/npx/main.py index cf11838..0fbfe99 100644 --- a/npx/main.py +++ b/npx/main.py @@ -84,6 +84,9 @@ def unique_rows(a, return_inverse=False, return_counts=False): # The numpy alternative `np.unique(a, axis=0)` is slow; cf. # . a = np.asarray(a) + if not np.issubdtype(a.dtype, np.integer): + raise ValueError(f"Input array must be integer type, got {a.dtype}.") + b = np.ascontiguousarray(a).view(np.dtype((np.void, a.dtype.itemsize * a.shape[1]))) out = np.unique(b, return_inverse=return_inverse, return_counts=return_counts) # out[0] are the sorted, unique rows diff --git a/test/test_npx.py b/test/test_npx.py index d7a3115..5611e57 100644 --- a/test/test_npx.py +++ b/test/test_npx.py @@ -1,4 +1,5 @@ import numpy as np +import pytest import npx @@ -60,3 +61,7 @@ def test_unique_rows(): a_unique = npx.unique_rows(a) assert np.all(a_unique == [[1, 2], [1, 4]]) + + a = [1.1, 2.2] + with pytest.raises(ValueError): + a_unique = npx.unique_rows(a)