Skip to content

Commit

Permalink
Fix a crash and an unclear error message.
Browse files Browse the repository at this point in the history
1. Fixes a spurious crash when using an implicit solver with DirectAdjoint.
2. Fixes the unclear error message when using an implicit solver without an adaptive step size controller.
  • Loading branch information
patrick-kidger committed Jan 8, 2024
1 parent e240aea commit 7965e89
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 24 deletions.
4 changes: 3 additions & 1 deletion diffrax/_adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,9 @@ def loop(
# Support forward-mode autodiff.
# TODO: remove this hack once we can JVP through custom_vjps.
if isinstance(solver, AbstractRungeKutta) and solver.scan_kind is None:
solver = eqx.tree_at(lambda s: s.scan_kind, solver, "bounded")
solver = eqx.tree_at(
lambda s: s.scan_kind, solver, "bounded", is_leaf=_is_none
)
inner_while_loop = ft.partial(_inner_loop, kind=kind)
outer_while_loop = ft.partial(_outer_loop, kind=kind)
final_state = self._loop(
Expand Down
61 changes: 39 additions & 22 deletions diffrax/_integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import typing
import warnings
from collections.abc import Callable
from typing import Any, cast, get_args, get_origin, Optional, Tuple, TYPE_CHECKING
from typing import Any, get_args, get_origin, Optional, Tuple, TYPE_CHECKING

import equinox as eqx
import equinox.internal as eqxi
Expand Down Expand Up @@ -736,27 +736,44 @@ def _wrap(term):
is_leaf=lambda x: isinstance(x, AbstractTerm) and not isinstance(x, MultiTerm),
)

if isinstance(stepsize_controller, AbstractAdaptiveStepSizeController):
if isinstance(solver, AbstractImplicitSolver):
if solver.root_finder.rtol is use_stepsize_tol:
solver = eqx.tree_at(
lambda s: s.root_finder.rtol,
solver,
stepsize_controller.rtol,
)
solver = cast(AbstractImplicitSolver, solver)
if solver.root_finder.atol is use_stepsize_tol:
solver = eqx.tree_at(
lambda s: s.root_finder.atol,
solver,
stepsize_controller.atol,
)
solver = cast(AbstractImplicitSolver, solver)
if solver.root_finder.norm is use_stepsize_tol:
solver = eqx.tree_at(
lambda s: s.root_finder.norm,
solver,
stepsize_controller.norm,
if isinstance(solver, AbstractImplicitSolver):

def _get_tols(x):
outs = []
for attr in ("rtol", "atol", "norm"):
if getattr(solver.root_finder, attr) is use_stepsize_tol: # pyright: ignore
outs.append(getattr(x, attr))
return tuple(outs)

if isinstance(stepsize_controller, AbstractAdaptiveStepSizeController):
solver = eqx.tree_at(
lambda s: _get_tols(s.root_finder),
solver,
_get_tols(stepsize_controller),
)
else:
if len(_get_tols(solver.root_finder)) > 0:
raise ValueError(
"A fixed step size controller is being used alongside an implicit "
"solver, but the tolerances for the implicit solver have not been "
"specified. (Being unspecified is the default in Diffrax.)\n"
"The correct fix is almost always to use an adaptive step size "
"controller. For example "
"`diffrax.diffeqsolve(..., "
"stepsize_controller=diffrax.PIDController(rtol=..., atol=...))`. "
"In this case the same tolerances are used for the implicit "
"solver as are used to control the adaptive stepping.\n"
"(Note for advanced users: the tolerances for the implicit "
"solver can also be explicitly set instead. For example "
"`diffrax.diffeqsolve(..., solver=diffrax.Kvaerno5(root_finder="
"diffrax.VeryChord(rtol=..., atol=..., "
"norm=optimistix.max_norm)))`. In this case the norm must also be "
"explicitly specified.)\n"
"Adaptive step size controllers are the preferred solution, as "
"sometimes the implicit solver may fail to converge, and in this "
"case an adaptive step size controller can reject the step and try "
"a smaller one, whilst with a fixed step size controller the "
"overall differential equation solve will simply fail."
)

# Error checking
Expand Down
10 changes: 9 additions & 1 deletion diffrax/_root_finder/_with_tols.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,15 @@
import optimistix as optx


use_stepsize_tol = object()
class _UseStepSizeTol:
def __repr__(self):
return (
"<tolerance taken from `diffeqsolve(..., stepsize_controller=...)` "
"argument>"
)


use_stepsize_tol = _UseStepSizeTol()


def with_stepsize_controller_tols(cls: type[optx.AbstractRootFinder]):
Expand Down
13 changes: 13 additions & 0 deletions test/test_adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,3 +363,16 @@ def run(y0__args, adjoint):
grads3 = run((y0, args), diffrax.RecursiveCheckpointAdjoint())
assert tree_allclose(grads1, grads2, rtol=1e-3, atol=1e-3)
assert tree_allclose(grads1, grads3, rtol=1e-3, atol=1e-3)


def test_implicit_runge_kutta_direct_adjoint():
diffrax.diffeqsolve(
diffrax.ODETerm(lambda t, y, args: -y),
diffrax.Kvaerno5(),
0,
1,
0.01,
1.0,
adjoint=diffrax.DirectAdjoint(),
stepsize_controller=diffrax.PIDController(rtol=1e-3, atol=1e-6),
)
13 changes: 13 additions & 0 deletions test/test_integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,3 +504,16 @@ def vector_field(t, y, args):
assert text == "static_made_jump=False static_result=None\n"
finally:
diffrax._integrate._PRINT_STATIC = False


def test_implicit_tol_error():
msg = "the tolerances for the implicit solver have not been specified"
with pytest.raises(ValueError, match=msg):
diffrax.diffeqsolve(
diffrax.ODETerm(lambda t, y, args: -y),
diffrax.Kvaerno5(),
0,
1,
0.01,
1.0,
)

0 comments on commit 7965e89

Please sign in to comment.