Skip to content

Commit

Permalink
Merge pull request #29 from nschloe/outer
Browse files Browse the repository at this point in the history
add npx.outer
  • Loading branch information
nschloe authored Jan 16, 2022
2 parents cda7d8e + 5e25310 commit 3dc43cb
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 2 deletions.
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[metadata]
name = npx
version = 0.0.24
version = 0.0.25
author = Nico Schlömer
author_email = nico.schloemer@gmail.com
description = Some useful extensions for NumPy
Expand Down
3 changes: 2 additions & 1 deletion src/npx/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from ._isin import isin_rows
from ._main import add_at, dot, solve, subtract_at, sum_at
from ._main import add_at, dot, outer, solve, subtract_at, sum_at
from ._mean import mean
from ._unique import unique, unique_rows

__all__ = [
"dot",
"outer",
"solve",
"sum_at",
"add_at",
Expand Down
9 changes: 9 additions & 0 deletions src/npx/_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,15 @@ def dot(a: ArrayLike, b: ArrayLike) -> np.ndarray:
return np.dot(a, b.reshape(b.shape[0], -1)).reshape(a.shape[:-1] + b.shape[1:])


def outer(a: ArrayLike, b: ArrayLike) -> np.ndarray:
"""Compute the outer product of two arrays `a` and `b` such that the shape
of the resulting array is `(*a.shape, *b.shape)`.
"""
a = np.asarray(a)
b = np.asarray(b)
return np.outer(a, b).reshape(*a.shape, *b.shape)


def solve(A: ArrayLike, x: ArrayLike) -> np.ndarray:
"""Solves a linear equation system with a matrix of shape (n, n) and an array of
shape (n, ...). The output has the same shape as the second argument.
Expand Down
7 changes: 7 additions & 0 deletions tests/test_dot_solve.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,10 @@ def test_solve():
b = np.random.rand(3, 4, 5)
c = npx.solve(a, b)
assert c.shape == b.shape


def test_outer():
a = np.random.rand(1, 2)
b = np.random.rand(3, 4)
c = npx.outer(a, b)
assert c.shape == (1, 2, 3, 4)

0 comments on commit 3dc43cb

Please sign in to comment.