Skip to content

Commit

Permalink
Merge pull request #13 from nschloe/krylov-first-residual
Browse files Browse the repository at this point in the history
Krylov first residuals
  • Loading branch information
nschloe authored Apr 25, 2021
2 parents 7209683 + cc5e0ae commit 0fd2ef1
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 49 deletions.
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ and `info` contains some useful data, e.g., `info.resnorms`. The methods are wra
around [SciPy's iterative
solvers](https://docs.scipy.org/doc/scipy/reference/sparse.linalg.html).

Relevant issues:
* [inconsistent number of callback calls between cg, minres](https://github.com/scipy/scipy/issues/13936)


#### SciPy minimization
```python
Expand Down
59 changes: 35 additions & 24 deletions npx/_krylov.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import scipy
import scipy.sparse.linalg

Info = namedtuple("KrylovInfo", ["success", "xk", "resnorms", "errnorms"])
Info = namedtuple("KrylovInfo", ["success", "xk", "numsteps", "resnorms", "errnorms"])


def cg(
Expand All @@ -19,18 +19,27 @@ def cg(
atol: Optional[float] = 0.0,
exact_solution=None,
):
resnorms = []

if x0 is None:
x0 = np.zeros(A.shape[1])

# initial residual
resnorms = []
r = b - A @ x0
Mr = r if M is None else M @ r
resnorms.append(np.sqrt(np.dot(r, Mr)))

if exact_solution is None:
errnorms = None
else:
err = exact_solution - x0
errnorms = [np.sqrt(np.dot(err, err))]

num_steps = 0

def cb(xk):
nonlocal num_steps
num_steps += 1

if callback is not None:
callback(xk)

Expand All @@ -48,11 +57,7 @@ def cb(xk):

success = info == 0

resnorms = np.array(resnorms)
if errnorms is not None:
errnorms = np.array(errnorms)

return x if success else None, Info(success, x, resnorms, errnorms)
return x if success else None, Info(success, x, num_steps, resnorms, errnorms)


def gmres(
Expand All @@ -67,18 +72,23 @@ def gmres(
atol: Optional[float] = 0.0,
exact_solution=None,
):
resnorms = []

if x0 is None:
x0 = np.zeros(A.shape[1])

# scipy.gmres() apparently calls the callback before the start of the iteration such
# that the initial residual is automatically contained
resnorms = []
num_steps = -1

if exact_solution is None:
errnorms = None
else:
err = exact_solution - x0
errnorms = [np.sqrt(np.dot(err, err))]
errnorms = []

def cb(xk):
nonlocal num_steps
num_steps += 1

if callback is not None:
callback(xk)

Expand All @@ -105,11 +115,7 @@ def cb(xk):

success = info == 0

resnorms = np.array(resnorms)
if errnorms is not None:
errnorms = np.array(errnorms)

return x if success else None, Info(success, x, resnorms, errnorms)
return x if success else None, Info(success, x, num_steps, resnorms, errnorms)


def minres(
Expand All @@ -123,18 +129,27 @@ def minres(
callback: Optional[Callable] = None,
exact_solution=None,
):
resnorms = []

if x0 is None:
x0 = np.zeros(A.shape[1])

# initial residual
resnorms = []
r = b - A @ x0
Mr = r if M is None else M @ r
resnorms.append(np.sqrt(np.dot(r, Mr)))

if exact_solution is None:
errnorms = None
else:
err = exact_solution - x0
errnorms = [np.sqrt(np.dot(err, err))]

num_steps = 0

def cb(xk):
nonlocal num_steps
num_steps += 1

if callback is not None:
callback(xk)

Expand All @@ -152,8 +167,4 @@ def cb(xk):

success = info == 0

resnorms = np.array(resnorms)
if errnorms is not None:
errnorms = np.array(errnorms)

return x if success else None, Info(success, x, resnorms, errnorms)
return x if success else None, Info(success, x, num_steps, resnorms, errnorms)
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[metadata]
name = npx
version = 0.0.10
version = 0.0.11
author = Nico Schlömer
author_email = nico.schloemer@gmail.com
description = Some useful extensions for NumPy
Expand Down
57 changes: 33 additions & 24 deletions test/test_krylov.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,43 +14,50 @@ def _run(method, resnorms1, resnorms2, tol=1.0e-13):

exact_solution = scipy.sparse.linalg.spsolve(A, b)

sol, info = method(A, b, exact_solution=exact_solution)
x0 = np.zeros(A.shape[1])
sol, info = method(A, b, x0, exact_solution=exact_solution, callback=lambda _: None)
assert sol is not None
assert info.success
resnorms1 = np.asarray(resnorms1)
for x in info.resnorms:
print(f"{x:.15e}")
print(info)
assert len(info.resnorms) == info.numsteps + 1
assert len(info.errnorms) == info.numsteps + 1
print(info.resnorms)
print()
resnorms1 = np.asarray(resnorms1)
assert np.all(np.abs(info.resnorms - resnorms1) < tol * (1 + resnorms1))
# make sure the initial resnorm and errnorm are correct
assert abs(np.linalg.norm(A @ x0 - b, 2) - info.resnorms[0]) < 1.0e-13
assert abs(np.linalg.norm(x0 - exact_solution, 2) - info.errnorms[0]) < 1.0e-13

# with "preconditioning"
M = scipy.sparse.linalg.LinearOperator((n, n), matvec=lambda x: 0.5 * x)
sol, info = method(A, b, M=M)

assert sol is not None
assert info.success
print(info.resnorms)
resnorms2 = np.asarray(resnorms2)
for x in info.resnorms:
print(f"{x:.15e}")
assert np.all(np.abs(info.resnorms - resnorms2) < tol * (1 + resnorms2))


def test_cg():
_run(
npx.cg,
[
6.324555320336759e00,
4.898979485566356e00,
3.464101615137754e00,
2.000000000000000e00,
0.000000000000000e00,
3.1622776601683795,
6.324555320336759,
4.898979485566356,
3.4641016151377544,
2.0,
0.0,
],
[
4.472135954999580e00,
3.464101615137754e00,
2.449489742783178e00,
1.414213562373095e00,
0.000000000000000e00,
2.23606797749979,
4.47213595499958,
3.4641016151377544,
2.449489742783178,
1.4142135623730951,
0.0,
],
)

Expand All @@ -67,17 +74,19 @@ def test_minres():
_run(
npx.minres,
[
2.828427124746190e00,
2.449489742783178e00,
2.000000000000000e00,
1.414213562373095e00,
3.1622776601683795,
2.8284271247461903,
2.449489742783178,
2.0,
1.4142135623730951,
8.747542958250513e-15,
],
[
2.000000000000000e00,
1.732050807568877e00,
1.414213562373095e00,
1.000000000000000e00,
2.23606797749979,
2.0,
1.7320508075688772,
1.4142135623730951,
1.0,
5.475099487534308e-15,
],
)
Expand Down

0 comments on commit 0fd2ef1

Please sign in to comment.