Skip to content

Commit

Permalink
Merge pull request #3 from nschloe/krylov
Browse files Browse the repository at this point in the history
Krylov
  • Loading branch information
nschloe authored Mar 25, 2021
2 parents 620253f + 81ce22f commit 78ab4a7
Show file tree
Hide file tree
Showing 6 changed files with 259 additions and 6 deletions.
18 changes: 14 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@
[![LGTM](https://img.shields.io/lgtm/grade/python/github/nschloe/npx.svg?style=flat-square)](https://lgtm.com/projects/g/nschloe/npx)
[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg?style=flat-square)](https://github.com/psf/black)

[NumPy](https://numpy.org/) is a large library used everywhere in scientific computing.
That's why breaking backwards-compatibility is comes as a significant cost and is almost
always avoided, even if the API of some methods is arguably confusing. This package
provides drop-in wrappers "fixing" those.
[NumPy](https://numpy.org/) and [SciPy](https://www.scipy.org/) are large libraries used
everywhere in scientific computing. That's why breaking backwards-compatibility comes as
a significant cost and is almost always avoided, even if the API of some methods is
arguably lacking. This package provides drop-in wrappers "fixing" those.

If you have a fix for a NumPy method that can't go upstream for some reason, feel free
to PR here.
Expand Down Expand Up @@ -52,5 +52,15 @@ to PR here.
Returns the unique rows of the integer array `a`. The numpy alternative `np.unique(a,
axis=0)` [is slow](https://github.com/numpy/numpy/issues/11136).

* ```python
sol, info = npx.cg(A, b, tol=1.0e-10)
sol, info = npx.minres(A, b, tol=1.0e-10)
sol, info = npx.gmres(A, b, tol=1.0e-10)
```
`sol` is the solution of the linear system `A @ x = b` (or `None` if no convergence),
and `info` contains some useful data, e.g., `info.resnorms`. The methods are wrappers
around [SciPy's iterative
solvers](https://docs.scipy.org/doc/scipy/reference/sparse.linalg.html).

### License
npx is published under the [MIT license](https://en.wikipedia.org/wiki/MIT_License).
6 changes: 5 additions & 1 deletion npx/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .__about__ import __version__
from .main import add_at, dot, solve, subtract_at, sum_at, unique_rows
from ._krylov import cg, gmres, minres
from ._main import add_at, dot, solve, subtract_at, sum_at, unique_rows

__all__ = [
"__version__",
Expand All @@ -9,4 +10,7 @@
"add_at",
"subtract_at",
"unique_rows",
"cg",
"gmres",
"minres",
]
153 changes: 153 additions & 0 deletions npx/_krylov.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
from collections import namedtuple
from typing import Optional

import numpy as np
import scipy
import scipy.sparse.linalg

Info = namedtuple("KrylovInfo", ["success", "xk", "resnorms", "errnorms"])


def cg(
A,
b,
x0=None,
tol: float = 1e-05,
maxiter: Optional[int] = None,
M=None,
callback=None,
atol: Optional[float] = None,
exact_solution=None,
):
resnorms = []

if exact_solution is None:
errnorms = None
else:
err = exact_solution - x0
errnorms = [np.sqrt(np.dot(err, err))]

def cb(xk):
if callback is not None:
callback(xk)

res = b - A @ xk
if M is not None:
res = M @ res
resnorms.append(np.sqrt(np.dot(res, res)))

if exact_solution is not None:
err = exact_solution - x0
errnorms.append(np.sqrt(np.dot(err, err)))

x, info = scipy.sparse.linalg.cg(
A, b, x0=x0, tol=tol, maxiter=maxiter, M=M, atol=atol, callback=cb
)

success = info == 0

resnorms = np.array(resnorms)
if errnorms is not None:
errnorms = np.array(errnorms)

return x if success else None, Info(success, x, resnorms, errnorms)


def gmres(
A,
b,
x0=None,
tol: float = 1e-05,
restart: Optional[int] = None,
maxiter: Optional[int] = None,
M=None,
callback=None,
atol: Optional[float] = None,
exact_solution=None,
):
resnorms = []

if exact_solution is None:
errnorms = None
else:
err = exact_solution - x0
errnorms = [np.sqrt(np.dot(err, err))]

def cb(xk):
if callback is not None:
callback(xk)

res = b - A @ xk
if M is not None:
res = M @ res
resnorms.append(np.sqrt(np.dot(res, res)))

if exact_solution is not None:
err = exact_solution - x0
errnorms.append(np.sqrt(np.dot(err, err)))

x, info = scipy.sparse.linalg.gmres(
A,
b,
x0=x0,
tol=tol,
restart=restart,
maxiter=maxiter,
M=M,
atol=atol,
callback=cb,
callback_type="x",
)

success = info == 0

resnorms = np.array(resnorms)
if errnorms is not None:
errnorms = np.array(errnorms)

return x if success else None, Info(success, x, resnorms, errnorms)


def minres(
A,
b,
x0=None,
shift: float = 0.0,
tol: float = 1e-05,
maxiter: Optional[int] = None,
M=None,
callback=None,
exact_solution=None,
):
resnorms = []

if exact_solution is None:
errnorms = None
else:
err = exact_solution - x0
errnorms = [np.sqrt(np.dot(err, err))]

def cb(xk):
if callback is not None:
callback(xk)

res = b - A @ xk
if M is not None:
res = M @ res
resnorms.append(np.sqrt(np.dot(res, res)))

if exact_solution is not None:
err = exact_solution - x0
errnorms.append(np.sqrt(np.dot(err, err)))

x, info = scipy.sparse.linalg.minres(
A, b, x0=x0, shift=shift, tol=tol, maxiter=maxiter, M=M, callback=cb
)

success = info == 0

resnorms = np.array(resnorms)
if errnorms is not None:
errnorms = np.array(errnorms)

return x if success else None, Info(success, x, resnorms, errnorms)
File renamed without changes.
3 changes: 2 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[metadata]
name = npx
version = 0.0.2
version = 0.0.3
author = Nico Schlömer
author_email = nico.schloemer@gmail.com
description = Some useful extensions for NumPy
Expand Down Expand Up @@ -33,6 +33,7 @@ packages = find:
install_requires =
importlib_metadata;python_version<"3.8"
numpy
scipy
python_requires = >=3.6

[options.entry_points]
Expand Down
85 changes: 85 additions & 0 deletions test/test_krylov.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import numpy as np
import scipy.sparse.linalg

import npx


def _run(fun, resnorms1, resnorms2, tol=1.0e-13):
n = 10
data = -np.ones((3, n))
data[1] = 2.0
A = scipy.sparse.spdiags(data, [-1, 0, 1], n, n)
A = A.tocsr()
b = np.ones(n)

sol, info = fun(A, b)
assert sol is not None
assert info.success
resnorms1 = np.asarray(resnorms1)
for x in info.resnorms:
print(f"{x:.15e}")
print()
assert np.all(np.abs(info.resnorms - resnorms1) < tol * (1 + resnorms1))

# with "preconditioning"
M = scipy.sparse.linalg.LinearOperator((n, n), matvec=lambda x: 0.5 * x)
sol, info = fun(A, b, M=M)

assert sol is not None
assert info.success
resnorms2 = np.asarray(resnorms2)
for x in info.resnorms:
print(f"{x:.15e}")
assert np.all(np.abs(info.resnorms - resnorms2) < tol * (1 + resnorms2))


def test_cg():
_run(
npx.cg,
[
6.324555320336759e00,
4.898979485566356e00,
3.464101615137754e00,
2.000000000000000e00,
0.000000000000000e00,
],
[
3.162277660168380e00,
2.449489742783178e00,
1.732050807568877e00,
1.000000000000000e00,
0.000000000000000e00,
],
)


def test_gmres():
_run(
npx.gmres,
[3.162277660168380e00, 7.160723346098895e-15],
[1.581138830084190e00, 3.580361673049448e-15],
)


def test_minres():
_run(
npx.minres,
[
2.828427124746190e00,
2.449489742783178e00,
2.000000000000000e00,
1.414213562373095e00,
8.747542958250513e-15,
],
[
1.414213562373095e00,
1.224744871391589e00,
1.000000000000000e00,
7.071067811865476e-01,
3.871479975306501e-15,
],
)


if __name__ == "__main__":
test_gmres()

0 comments on commit 78ab4a7

Please sign in to comment.