Skip to content

Commit

Permalink
Merge pull request #5 from nschloe/unique-rows-1dim
Browse files Browse the repository at this point in the history
unique_rows for 1D
  • Loading branch information
nschloe authored Apr 16, 2021
2 parents f212961 + 3ad9de3 commit 6de70e4
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 5 deletions.
7 changes: 4 additions & 3 deletions npx/_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,12 +87,13 @@ def unique_rows(a, return_inverse=False, return_counts=False):
if not np.issubdtype(a.dtype, np.integer):
raise ValueError(f"Input array must be integer type, got {a.dtype}.")

b = np.ascontiguousarray(a).view(np.dtype((np.void, a.dtype.itemsize * a.shape[1])))
p = np.prod(a.shape[1:])
b = np.ascontiguousarray(a).view(np.dtype((np.void, a.dtype.itemsize * p)))
out = np.unique(b, return_inverse=return_inverse, return_counts=return_counts)
# out[0] are the sorted, unique rows
if return_inverse or return_counts:
out = (out[0].view(a.dtype).reshape(-1, a.shape[1]), *out[1:])
out = (out[0].view(a.dtype).reshape(-1, *a.shape[1:]), *out[1:])
else:
out = out.view(a.dtype).reshape(-1, a.shape[1])
out = out.view(a.dtype).reshape(-1, *a.shape[1:])

return out
3 changes: 1 addition & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[metadata]
name = npx
version = 0.0.4
version = 0.0.5
author = Nico Schlömer
author_email = nico.schloemer@gmail.com
description = Some useful extensions for NumPy
Expand All @@ -12,7 +12,6 @@ project_urls =
long_description = file: README.md
long_description_content_type = text/markdown
license = MIT
license_file = LICENSE
# See <https://pypi.org/classifiers/> for all classifiers.
classifiers =
Development Status :: 4 - Beta
Expand Down
12 changes: 12 additions & 0 deletions test/test_npx.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ def test_subtract_at():


def test_unique_rows():
a = [1, 2, 1]
a_unique = npx.unique_rows(a)
assert np.all(a_unique == [1, 2])

a = [[1, 2], [1, 4], [1, 2]]
a_unique, inv, count = npx.unique_rows(a, return_inverse=True, return_counts=True)
assert np.all(a_unique == [[1, 2], [1, 4]])
Expand All @@ -62,6 +66,14 @@ def test_unique_rows():
a_unique = npx.unique_rows(a)
assert np.all(a_unique == [[1, 2], [1, 4]])

# entries are matrices
# fails for some reason. keep an eye on
# <https://stackoverflow.com/q/67128631/353337>
# a = [[[3, 4], [-1, 2]], [[3, 4], [-1, 2]]]
# a_unique = npx.unique_rows(a)
# print(a_unique)
# assert np.all(a_unique == [[[3, 4], [-1, 2]]])

a = [1.1, 2.2]
with pytest.raises(ValueError):
a_unique = npx.unique_rows(a)

0 comments on commit 6de70e4

Please sign in to comment.