From 3155f4a562b1ffead6f4141cd7d39330413cf0a6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nico=20Schl=C3=B6mer?= Date: Thu, 25 Mar 2021 22:00:39 +0100 Subject: [PATCH 1/6] add cg --- README.md | 8 +++--- npx/__init__.py | 4 ++- npx/_krylov.py | 53 +++++++++++++++++++++++++++++++++++++++ npx/{main.py => _main.py} | 0 test/test_krylov.py | 50 ++++++++++++++++++++++++++++++++++++ 5 files changed, 110 insertions(+), 5 deletions(-) create mode 100644 npx/_krylov.py rename npx/{main.py => _main.py} (100%) create mode 100644 test/test_krylov.py diff --git a/README.md b/README.md index a5d324b..33dc139 100644 --- a/README.md +++ b/README.md @@ -10,10 +10,10 @@ [![LGTM](https://img.shields.io/lgtm/grade/python/github/nschloe/npx.svg?style=flat-square)](https://lgtm.com/projects/g/nschloe/npx) [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg?style=flat-square)](https://github.com/psf/black) -[NumPy](https://numpy.org/) is a large library used everywhere in scientific computing. -That's why breaking backwards-compatibility is comes as a significant cost and is almost -always avoided, even if the API of some methods is arguably confusing. This package -provides drop-in wrappers "fixing" those. +[NumPy](https://numpy.org/) and [SciPy](https://www.scipy.org/) are large libraries used +everywhere in scientific computing. That's why breaking backwards-compatibility comes as +a significant cost and is almost always avoided, even if the API of some methods is +arguably lacking. This package provides drop-in wrappers "fixing" those. If you have a fix for a NumPy method that can't go upstream for some reason, feel free to PR here. diff --git a/npx/__init__.py b/npx/__init__.py index d0dd4ba..b448cf0 100644 --- a/npx/__init__.py +++ b/npx/__init__.py @@ -1,5 +1,6 @@ from .__about__ import __version__ -from .main import add_at, dot, solve, subtract_at, sum_at, unique_rows +from ._krylov import cg +from ._main import add_at, dot, solve, subtract_at, sum_at, unique_rows __all__ = [ "__version__", @@ -9,4 +10,5 @@ "add_at", "subtract_at", "unique_rows", + "cg", ] diff --git a/npx/_krylov.py b/npx/_krylov.py new file mode 100644 index 0000000..c7e33e7 --- /dev/null +++ b/npx/_krylov.py @@ -0,0 +1,53 @@ +from collections import namedtuple +from typing import Optional + +import numpy as np +import scipy +import scipy.sparse.linalg + +Info = namedtuple("KrylovInfo", ["success", "xk", "resnorms", "errnorms"]) + + +def cg( + A, + b, + x0=None, + tol: float = 1e-05, + maxiter: Optional[int] = None, + M=None, + callback=None, + atol: Optional[float] = None, + exact_solution=None, +): + resnorms = [] + + if exact_solution is None: + errnorms = None + else: + err = exact_solution - x0 + errnorms = [np.sqrt(np.dot(err, err))] + + def cb(xk): + if callback is not None: + callback(xk) + + res = b - A @ xk + if M is not None: + res = M @ res + resnorms.append(np.sqrt(np.dot(res, res))) + + if exact_solution is not None: + err = exact_solution - x0 + errnorms.append(np.sqrt(np.dot(err, err))) + + x, info = scipy.sparse.linalg.cg( + A, b, x0=x0, tol=tol, maxiter=maxiter, M=M, atol=atol, callback=cb + ) + + success = info == 0 + + resnorms = np.array(resnorms) + if errnorms is not None: + errnorms = np.array(errnorms) + + return x if success else None, Info(success, x, resnorms, errnorms) diff --git a/npx/main.py b/npx/_main.py similarity index 100% rename from npx/main.py rename to npx/_main.py diff --git a/test/test_krylov.py b/test/test_krylov.py new file mode 100644 index 0000000..bfd4419 --- /dev/null +++ b/test/test_krylov.py @@ -0,0 +1,50 @@ +import numpy as np +import scipy.sparse.linalg + +import npx + + +def test_cg(): + n = 10 + data = -np.ones((3, n)) + data[1] = 2.0 + A = scipy.sparse.spdiags(data, [-1, 0, 1], n, n) + A = A.tocsr() + b = np.ones(n) + + sol, info = npx.cg(A, b) + assert sol is not None + assert info.success + ref = np.array( + [ + 6.324555320336759e00, + 4.898979485566356e00, + 3.464101615137754e00, + 2.000000000000000e00, + 0.000000000000000e00, + ] + ) + tol = 1.0e-13 + assert np.all(np.abs(info.resnorms - ref) < tol * (1 + ref)) + + # with "preconditioning" + M = scipy.sparse.linalg.LinearOperator((n, n), matvec=lambda x: 0.5 * x) + sol, info = npx.cg(A, b, M=M) + + assert sol is not None + assert info.success + ref = np.array( + [ + 3.162277660168380e00, + 2.449489742783178e00, + 1.732050807568877e00, + 1.000000000000000e00, + 0.000000000000000e00, + ] + ) + tol = 1.0e-13 + assert np.all(np.abs(info.resnorms - ref) < tol * (1 + ref)) + + +if __name__ == "__main__": + test_cg() From 267036910acaf2165aed0d390a56c3047a660d00 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nico=20Schl=C3=B6mer?= Date: Thu, 25 Mar 2021 23:22:32 +0100 Subject: [PATCH 2/6] add gmres, minres --- npx/__init__.py | 4 +- npx/_krylov.py | 100 ++++++++++++++++++++++++++++++++++++++++++++ test/test_krylov.py | 73 +++++++++++++++++++++++--------- 3 files changed, 157 insertions(+), 20 deletions(-) diff --git a/npx/__init__.py b/npx/__init__.py index b448cf0..b6afcfc 100644 --- a/npx/__init__.py +++ b/npx/__init__.py @@ -1,5 +1,5 @@ from .__about__ import __version__ -from ._krylov import cg +from ._krylov import cg, gmres, minres from ._main import add_at, dot, solve, subtract_at, sum_at, unique_rows __all__ = [ @@ -11,4 +11,6 @@ "subtract_at", "unique_rows", "cg", + "gmres", + "minres", ] diff --git a/npx/_krylov.py b/npx/_krylov.py index c7e33e7..e0cf2d9 100644 --- a/npx/_krylov.py +++ b/npx/_krylov.py @@ -51,3 +51,103 @@ def cb(xk): errnorms = np.array(errnorms) return x if success else None, Info(success, x, resnorms, errnorms) + + +def gmres( + A, + b, + x0=None, + tol: float = 1e-05, + restart: Optional[int] = None, + maxiter: Optional[int] = None, + M=None, + callback=None, + atol: Optional[float] = None, + exact_solution=None, +): + resnorms = [] + + if exact_solution is None: + errnorms = None + else: + err = exact_solution - x0 + errnorms = [np.sqrt(np.dot(err, err))] + + def cb(xk): + if callback is not None: + callback(xk) + + res = b - A @ xk + if M is not None: + res = M @ res + resnorms.append(np.sqrt(np.dot(res, res))) + + if exact_solution is not None: + err = exact_solution - x0 + errnorms.append(np.sqrt(np.dot(err, err))) + + x, info = scipy.sparse.linalg.gmres( + A, + b, + x0=x0, + tol=tol, + restart=restart, + maxiter=maxiter, + M=M, + atol=atol, + callback=cb, + callback_type="x", + ) + + success = info == 0 + + resnorms = np.array(resnorms) + if errnorms is not None: + errnorms = np.array(errnorms) + + return x if success else None, Info(success, x, resnorms, errnorms) + + +def minres( + A, + b, + x0=None, + shift: float = 0.0, + tol: float = 1e-05, + maxiter: Optional[int] = None, + M=None, + callback=None, + exact_solution=None, +): + resnorms = [] + + if exact_solution is None: + errnorms = None + else: + err = exact_solution - x0 + errnorms = [np.sqrt(np.dot(err, err))] + + def cb(xk): + if callback is not None: + callback(xk) + + res = b - A @ xk + if M is not None: + res = M @ res + resnorms.append(np.sqrt(np.dot(res, res))) + + if exact_solution is not None: + err = exact_solution - x0 + errnorms.append(np.sqrt(np.dot(err, err))) + + x, info = scipy.sparse.linalg.minres( + A, b, x0=x0, shift=shift, tol=tol, maxiter=maxiter, M=M, callback=cb + ) + + success = info == 0 + + resnorms = np.array(resnorms) + if errnorms is not None: + errnorms = np.array(errnorms) + + return x if success else None, Info(success, x, resnorms, errnorms) diff --git a/test/test_krylov.py b/test/test_krylov.py index bfd4419..56d0ba1 100644 --- a/test/test_krylov.py +++ b/test/test_krylov.py @@ -4,7 +4,7 @@ import npx -def test_cg(): +def _run(fun, resnorms1, resnorms2, tol=1.0e-13): n = 10 data = -np.ones((3, n)) data[1] = 2.0 @@ -12,39 +12,74 @@ def test_cg(): A = A.tocsr() b = np.ones(n) - sol, info = npx.cg(A, b) + sol, info = fun(A, b) assert sol is not None assert info.success - ref = np.array( - [ - 6.324555320336759e00, - 4.898979485566356e00, - 3.464101615137754e00, - 2.000000000000000e00, - 0.000000000000000e00, - ] - ) - tol = 1.0e-13 - assert np.all(np.abs(info.resnorms - ref) < tol * (1 + ref)) + resnorms1 = np.asarray(resnorms1) + for x in info.resnorms: + print(f"{x:.15e}") + print() + assert np.all(np.abs(info.resnorms - resnorms1) < tol * (1 + resnorms1)) # with "preconditioning" M = scipy.sparse.linalg.LinearOperator((n, n), matvec=lambda x: 0.5 * x) - sol, info = npx.cg(A, b, M=M) + sol, info = fun(A, b, M=M) assert sol is not None assert info.success - ref = np.array( + resnorms2 = np.asarray(resnorms2) + for x in info.resnorms: + print(f"{x:.15e}") + assert np.all(np.abs(info.resnorms - resnorms2) < tol * (1 + resnorms2)) + + +def test_cg(): + _run( + npx.cg, + [ + 6.324555320336759e00, + 4.898979485566356e00, + 3.464101615137754e00, + 2.000000000000000e00, + 0.000000000000000e00, + ], [ 3.162277660168380e00, 2.449489742783178e00, 1.732050807568877e00, 1.000000000000000e00, 0.000000000000000e00, - ] + ], + ) + + +def test_gmres(): + _run( + npx.gmres, + [3.162277660168380e00, 7.160723346098895e-15], + [1.581138830084190e00, 3.580361673049448e-15], + ) + + +def test_minres(): + _run( + npx.minres, + [ + 2.828427124746190e00, + 2.449489742783178e00, + 2.000000000000000e00, + 1.414213562373095e00, + 8.747542958250513e-15, + ], + [ + 1.414213562373095e00, + 1.224744871391589e00, + 1.000000000000000e00, + 7.071067811865476e-01, + 3.871479975306501e-15, + ], ) - tol = 1.0e-13 - assert np.all(np.abs(info.resnorms - ref) < tol * (1 + ref)) if __name__ == "__main__": - test_cg() + test_gmres() From 6d45348b9c89ec6f8332c523bc5a78e8299eb590 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nico=20Schl=C3=B6mer?= Date: Thu, 25 Mar 2021 23:26:13 +0100 Subject: [PATCH 3/6] version bump --- setup.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index bc12bdb..312f270 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = npx -version = 0.0.2 +version = 0.0.3 author = Nico Schlömer author_email = nico.schloemer@gmail.com description = Some useful extensions for NumPy From d88fd41397c5e3179d2f59530204f59708771f05 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nico=20Schl=C3=B6mer?= Date: Thu, 25 Mar 2021 23:28:15 +0100 Subject: [PATCH 4/6] readme --- README.md | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/README.md b/README.md index 33dc139..8e7b8f7 100644 --- a/README.md +++ b/README.md @@ -52,5 +52,13 @@ to PR here. Returns the unique rows of the integer array `a`. The numpy alternative `np.unique(a, axis=0)` [is slow](https://github.com/numpy/numpy/issues/11136). +* ```python + sol, info = npx.cg(A, b, tol=1.0e-10) + sol, info = npx.minres(A, b, tol=1.0e-10) + sol, info = npx.gmres(A, b, tol=1.0e-10) + ``` + `sol` is the solution of the linear system `A @ x = b` (or `None` if no convergence), + and `info` contains some useful data, e.g., `info.resnorms`. + ### License npx is published under the [MIT license](https://en.wikipedia.org/wiki/MIT_License). From 0c59ad9598f9a73964e917eb0305681ff36dc724 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nico=20Schl=C3=B6mer?= Date: Thu, 25 Mar 2021 23:29:10 +0100 Subject: [PATCH 5/6] more readme --- README.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 8e7b8f7..6f1e74a 100644 --- a/README.md +++ b/README.md @@ -58,7 +58,9 @@ to PR here. sol, info = npx.gmres(A, b, tol=1.0e-10) ``` `sol` is the solution of the linear system `A @ x = b` (or `None` if no convergence), - and `info` contains some useful data, e.g., `info.resnorms`. + and `info` contains some useful data, e.g., `info.resnorms`. The methods are wrappers + around [SciPy's iterative + solvers](https://docs.scipy.org/doc/scipy/reference/sparse.linalg.html). ### License npx is published under the [MIT license](https://en.wikipedia.org/wiki/MIT_License). From 81ce22fc22a616a4be8251a1a7e17157367dc6bc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nico=20Schl=C3=B6mer?= Date: Thu, 25 Mar 2021 23:34:45 +0100 Subject: [PATCH 6/6] add scipy dependency --- setup.cfg | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.cfg b/setup.cfg index 312f270..b14e34b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -33,6 +33,7 @@ packages = find: install_requires = importlib_metadata;python_version<"3.8" numpy + scipy python_requires = >=3.6 [options.entry_points]