Merge pull request #24 from nschloe/bincount-workaround
bincount workaround
nschloe authored Sep 30, 2021
2 parents f9dbe4d + 7f27c55 commit 53476d6
Showing 4 changed files with 39 additions and 25 deletions.
28 changes: 14 additions & 14 deletions .github/workflows/ci.yml
@@ -3,10 +3,10 @@ name: ci
on:
push:
branches:
- main
- main
pull_request:
branches:
- main
- main

jobs:
lint:
@@ -23,16 +23,16 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.7, 3.8, 3.9]
python-version: ["3.7", "3.8", "3.9", "3.10-dev"]
steps:
- uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- uses: actions/checkout@v2
- name: Test with tox
run: |
pip install tox
tox -- --cov npx --cov-report xml --cov-report term
- name: Submit to codecov
uses: codecov/codecov-action@v1
if: ${{ matrix.python-version == '3.9' }}
- uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- uses: actions/checkout@v2
- name: Test with tox
run: |
pip install tox
tox -- --cov npx --cov-report xml --cov-report term
- name: Submit to codecov
uses: codecov/codecov-action@v1
if: ${{ matrix.python-version == '3.9' }}
29 changes: 19 additions & 10 deletions README.md
@@ -21,8 +21,8 @@ provides drop-in wrappers "fixing" those.
If you have a fix for a NumPy method that can't go upstream for some reason, feel free
to PR here.


#### `dot`

```python
import npx
import numpy as np
@@ -33,12 +33,13 @@ b = np.random.rand(5, 2, 2)
out = npx.dot(a, b)
# out.shape == (3, 4, 2, 2)
```

Forms the dot product between the last axis of `a` and the _first_ axis of `b`.

(Not the second-to-last axis of `b`, as `numpy.dot(a, b)` does.)
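
As a rough illustration of that contraction, the same result can be obtained with plain NumPy through `np.tensordot` with `axes=1`. This is only a sketch of the semantics; the diff does not show how `npx.dot` is implemented, and the `einsum` line is added purely as a cross-check.

```python
import numpy as np

a = np.random.rand(3, 4, 5)
b = np.random.rand(5, 2, 2)

# Contract the last axis of `a` with the first axis of `b`.
out = np.tensordot(a, b, axes=1)

# The same contraction spelled out with einsum, for comparison.
ref = np.einsum("ijk,klm->ijlm", a, b)

print(out.shape)
print(np.allclose(out, ref))
```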


#### `np.solve`

```python
import npx
import numpy as np
@@ -49,16 +50,19 @@ b = np.random.rand(3, 10, 4)
out = npx.solve(A, b)
# out.shape == (3, 10, 4)
```

Solves a linear equation system with a matrix of shape `(n, n)` and an array of shape
`(n, ...)`. The output has the same shape as the second argument.
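
One way to get this behavior with plain NumPy is to flatten the trailing axes of the right-hand side, solve once, and restore the shape. The sketch below assumes nothing about the internals of `npx.solve`; the diagonal shift on `A` is only there to guarantee a nonsingular example matrix.

```python
import numpy as np

# Strictly diagonally dominant, hence nonsingular (illustrative choice).
A = np.random.rand(3, 3) + 3.0 * np.eye(3)
b = np.random.rand(3, 10, 4)

# Treat every trailing entry of `b` as its own right-hand-side column.
b_flat = b.reshape(b.shape[0], -1)
x = np.linalg.solve(A, b_flat).reshape(b.shape)

# Check: applying A along the first axis of x recovers b.
print(np.allclose(np.tensordot(A, x, axes=1), b))
```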


#### `sum_at`/`add_at`

<!--pytest-codeblocks:skip-->

```python
npx.sum_at(a, idx, minlength=0)
npx.add_at(out, idx, a)
```

Returns an array with the entries of `a` summed up at the indices `idx`, with a minimum
length of `minlength`. `idx` can have any shape as long as it matches `a`. The output
shape is `(minlength,...)`.
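
For the one-dimensional case, the summation can be sketched with `np.bincount`, which is the call the `sum_at` source in this commit relies on. The example data is made up, and `np.add.at` appears only as a reference result.

```python
import numpy as np

a = np.array([1.0, 2.0, 3.0, 4.0])
idx = np.array([0, 2, 2, 5])

# bincount sums the weights that share an index; minlength pads with zeros.
out = np.bincount(idx, weights=a, minlength=8)

# Reference result via the unbuffered ufunc method.
ref = np.zeros(8)
np.add.at(ref, idx, a)

print(out)
```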
@@ -69,11 +73,12 @@ slower:
<img alt="memory usage" src="https://nschloe.github.io/npx/perf-add-at.svg" width="50%">

Relevant issue reports:
* [ufunc.at (and possibly other methods)
slow](https://github.com/numpy/numpy/issues/11156)

- [ufunc.at (and possibly other methods)
slow](https://github.com/numpy/numpy/issues/11156)

#### `unique_rows`

```python
import npx
import numpy as np
@@ -82,14 +87,15 @@ a = np.random.randint(0, 5, size=(100, 2))

npx.unique_rows(a, return_inverse=False, return_counts=False)
```
Returns the unique rows of the integer array `a`. The numpy alternative `np.unique(a,
axis=0)` is slow.

Returns the unique rows of the integer array `a`. The numpy alternative `np.unique(a, axis=0)` is slow.
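
For reference, this is what the plain NumPy alternative mentioned above returns on a small, made-up input. It is not a benchmark and does not show how `npx.unique_rows` works internally.

```python
import numpy as np

a = np.array([[1, 2], [3, 4], [1, 2], [0, 5], [3, 4]])

# Unique rows, sorted lexicographically, with the number of occurrences of each.
uniq, counts = np.unique(a, axis=0, return_counts=True)

print(uniq)
print(counts)
```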

Relevant issue reports:
* [unique() needlessly slow](https://github.com/numpy/numpy/issues/11136)

- [unique() needlessly slow](https://github.com/numpy/numpy/issues/11136)

#### `isin_rows`

```python
import npx
import numpy as np
@@ -99,25 +105,28 @@ b = np.random.randint(0, 5, size=(100, 2))

npx.isin_rows(a, b)
```

Returns a boolean array of length `len(a)` specifying if the rows `a[k]` appear in `b`.
Similar to NumPy's own `np.isin`, which only works for scalars.
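
The row-wise membership test can be illustrated with a brute-force comparison in plain NumPy. This sketch needs memory proportional to `len(a) * len(b)` and is meant only to pin down the semantics; it is an assumption-free stand-in, not the approach `npx.isin_rows` necessarily takes.

```python
import numpy as np

a = np.array([[0, 1], [2, 3], [4, 4]])
b = np.array([[2, 3], [9, 9], [0, 1]])

# Compare every row of `a` with every row of `b`: shape (len(a), len(b), ncols).
eq = a[:, None, :] == b[None, :, :]

# A row of `a` counts as present if all of its columns match some row of `b`.
mask = eq.all(axis=-1).any(axis=-1)

print(mask)
```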


#### `mean`

```python
import npx

a = [1.0, 2.0, 5.0]
npx.mean(a, p=3)
```

Returns the [generalized mean](https://en.wikipedia.org/wiki/Generalized_mean) of a
given list. Handles the cases `+-np.inf` (max/min) and `0` (geometric mean) correctly.
Also does well for large `p`.
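
A minimal sketch of a generalized mean that treats these special cases separately is given below. The function name `generalized_mean` is hypothetical, the input is assumed to be strictly positive, and the rescaling trick is one common way to avoid overflow for large `|p|`; the diff does not show whether `npx.mean` does the same.

```python
import numpy as np


def generalized_mean(x, p):
    """Power mean of strictly positive values (hypothetical helper)."""
    x = np.asarray(x, dtype=float)
    if p == np.inf:
        return x.max()
    if p == -np.inf:
        return x.min()
    if p == 0:
        # The limit p -> 0 is the geometric mean.
        return np.exp(np.mean(np.log(x)))
    # Rescale so that every power is at most 1 and cannot overflow.
    scale = x.max() if p > 0 else x.min()
    return scale * np.mean((x / scale) ** p) ** (1.0 / p)


a = [1.0, 2.0, 5.0]
print(generalized_mean(a, 3))
print(generalized_mean(a, 1000))
```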

Relevant NumPy issues:
* [generalized mean](https://github.com/numpy/numpy/issues/19341)

- [generalized mean](https://github.com/numpy/numpy/issues/19341)

### License

This software is published under the [BSD-3-Clause
license](https://spdx.org/licenses/BSD-3-Clause.html).
3 changes: 2 additions & 1 deletion setup.cfg
@@ -1,6 +1,6 @@
[metadata]
name = npx
version = 0.0.19
version = 0.0.20
author = Nico Schlömer
author_email = nico.schloemer@gmail.com
description = Some useful extensions for NumPy
@@ -23,6 +23,7 @@ classifiers =
Programming Language :: Python :: 3.7
Programming Language :: Python :: 3.8
Programming Language :: Python :: 3.9
Programming Language :: Python :: 3.10
Topic :: Scientific/Engineering
Topic :: Utilities

4 changes: 4 additions & 0 deletions src/npx/_main.py
@@ -59,6 +59,10 @@ def sum_at(a: npt.ArrayLike, indices: npt.ArrayLike, minlength: int):
indices = indices.reshape(-1)
a = a.reshape(_prod(a.shape[:m]), _prod(a.shape[m:]))

# Cast to int; bincount doesn't work for uint64 yet
# https://github.com/numpy/numpy/issues/17760
indices = indices.astype(int)

return np.array(
[
np.bincount(indices, weights=a[:, k], minlength=minlength)
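
The effect of the added cast can be sketched in isolation. The example data is made up; according to the linked issue, passing `uint64` indices to `np.bincount` directly failed at the time of this commit, and whether it still does depends on the NumPy version in use. Note that `astype(int)` assumes the indices fit into a signed integer.

```python
import numpy as np

indices = np.array([0, 2, 2, 5], dtype=np.uint64)
weights = np.array([1.0, 2.0, 3.0, 4.0])

# Cast to the default signed integer type before calling bincount,
# mirroring the workaround in this commit.
out = np.bincount(indices.astype(int), weights=weights, minlength=6)

print(out)
```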