From dee6fc78aa22942fa20bf3d30ec5b54a951ef6ba Mon Sep 17 00:00:00 2001 From: Andrius Ovsianas Date: Mon, 7 Nov 2022 21:12:46 +0000 Subject: [PATCH 1/5] Brownian motion classes accept pytrees for shape and dtype arguments (#183) * Brownian motion classes accept pytree shapes and dtypes * adjusted brownian tests for pytree shapes and dtypes * Change docs for brownian motion class arguments shape and dtype * Make shape of type Union[Tuple[int, ...], PyTree[jax.ShapeDtypeStruct]] * Change error_if to ValueError on shape check * Better shape documentation * Typing bug: compatability with python 3.8 * sys.modules error on testing --- diffrax/brownian/base.py | 6 ++-- diffrax/brownian/path.py | 50 +++++++++++++++++++++----- diffrax/brownian/tree.py | 77 +++++++++++++++++++++++++++++++--------- diffrax/misc/__init__.py | 2 ++ diffrax/misc/misc.py | 15 +++++++- test/conftest.py | 2 +- test/test_brownian.py | 71 ++++++++++++++++++++++++++++++++++-- 7 files changed, 190 insertions(+), 33 deletions(-) diff --git a/diffrax/brownian/base.py b/diffrax/brownian/base.py index 0bf0d4a7..ee90ba60 100644 --- a/diffrax/brownian/base.py +++ b/diffrax/brownian/base.py @@ -1,6 +1,6 @@ import abc -from ..custom_types import Array, Scalar +from ..custom_types import Array, PyTree, Scalar from ..path import AbstractPath @@ -8,7 +8,7 @@ class AbstractBrownianPath(AbstractPath): "Abstract base class for all Brownian paths." @abc.abstractmethod - def evaluate(self, t0: Scalar, t1: Scalar, left: bool = True) -> Array: + def evaluate(self, t0: Scalar, t1: Scalar, left: bool = True) -> PyTree[Array]: r"""Samples a Brownian increment $w(t_1) - w(t_0)$. Each increment has distribution $\mathcal{N}(0, t_1 - t_0)$. @@ -23,7 +23,7 @@ def evaluate(self, t0: Scalar, t1: Scalar, left: bool = True) -> Array: **Returns:** - A JAX array corresponding to the increment $w(t_1) - w(t_0)$. + A pytree of JAX arrays corresponding to the increment $w(t_1) - w(t_0)$. Some subclasses may allow `t1=None`, in which case just the value $w(t_0)$ is returned. diff --git a/diffrax/brownian/path.py b/diffrax/brownian/path.py index 933a0926..a7f16a4a 100644 --- a/diffrax/brownian/path.py +++ b/diffrax/brownian/path.py @@ -1,11 +1,18 @@ -from typing import Tuple +from typing import Tuple, Union import equinox as eqx +import jax import jax.numpy as jnp import jax.random as jrandom - -from ..custom_types import Array, Scalar -from ..misc import force_bitcast_convert_type, nondifferentiable_input +import jax.tree_util as jtu + +from ..custom_types import Array, PyTree, Scalar +from ..misc import ( + force_bitcast_convert_type, + is_tuple_of_ints, + nondifferentiable_input, + split_by_tree, +) from .base import AbstractBrownianPath @@ -29,11 +36,26 @@ class UnsafeBrownianPath(AbstractBrownianPath): correlation structure isn't needed.) """ - shape: Tuple[int] = eqx.static_field() + shape: PyTree[jax.ShapeDtypeStruct] = eqx.static_field() # Handled as a string because PRNGKey is actually a function, not a class, which # makes it appearly badly in autogenerated documentation. 
key: "jax.random.PRNGKey" # noqa: F821 + def __init__( + self, + shape: Union[Tuple[int, ...], PyTree[jax.ShapeDtypeStruct]], + key: "jax.random.PRNGKey", + ): + self.shape = ( + jax.ShapeDtypeStruct(shape, None) if is_tuple_of_ints(shape) else shape + ) + self.key = key + if any( + not jnp.issubdtype(x.dtype, jnp.inexact) + for x in jtu.tree_leaves(self.shape) + ): + raise ValueError("UnsafeBrownianPath dtypes all have to be floating-point.") + @property def t0(self): return None @@ -43,7 +65,7 @@ def t1(self): return None @eqx.filter_jit - def evaluate(self, t0: Scalar, t1: Scalar, left: bool = True) -> Array: + def evaluate(self, t0: Scalar, t1: Scalar, left: bool = True) -> PyTree[Array]: del left nondifferentiable_input(t0, "t0") nondifferentiable_input(t1, "t1") @@ -51,12 +73,24 @@ def evaluate(self, t0: Scalar, t1: Scalar, left: bool = True) -> Array: t1_ = force_bitcast_convert_type(t1, jnp.int32) key = jrandom.fold_in(self.key, t0_) key = jrandom.fold_in(key, t1_) - return jrandom.normal(key, self.shape) * jnp.sqrt(t1 - t0) + key = split_by_tree(key, self.shape) + return jtu.tree_map( + lambda key, shape: self._evaluate_leaf(t0, t1, key, shape), key, self.shape + ) + + def _evaluate_leaf(self, t0: Scalar, t1: Scalar, key, shape: jax.ShapeDtypeStruct): + return jrandom.normal(key, shape.shape, shape.dtype) * jnp.sqrt(t1 - t0).astype( + shape.dtype + ) UnsafeBrownianPath.__init__.__doc__ = """ **Arguments:** -- `shape`: What shape each individual Brownian sample should be. +- `shape`: Should be a PyTree of `jax.ShapeDtypeStruct`s, representing the shape, +dtype, and PyTree structure of the output. For simplicity, `shape` can also just +be a tuple of integers, describing the shape of a single JAX array. In that case +the dtype is chosen to be `float64` if `JAX_ENABLE_X64=True` and `float32` +otherwise. - `key`: A random key. """ diff --git a/diffrax/brownian/tree.py b/diffrax/brownian/tree.py index 6760f3d2..e2e7b1c8 100644 --- a/diffrax/brownian/tree.py +++ b/diffrax/brownian/tree.py @@ -1,14 +1,15 @@ from dataclasses import field -from typing import Optional, Tuple +from typing import Optional, Tuple, Union import equinox as eqx import jax import jax.lax as lax import jax.numpy as jnp import jax.random as jrandom +import jax.tree_util as jtu -from ..custom_types import Array, Scalar -from ..misc import error_if +from ..custom_types import Array, PyTree, Scalar +from ..misc import error_if, is_tuple_of_ints, split_by_tree from .base import AbstractBrownianPath @@ -58,29 +59,67 @@ class VirtualBrownianTree(AbstractBrownianPath): t0: Scalar = field(init=True) t1: Scalar = field(init=True) # override init=False in AbstractPath tol: Scalar - shape: Tuple[int] = eqx.static_field() + shape: PyTree[jax.ShapeDtypeStruct] = eqx.static_field() key: "jax.random.PRNGKey" # noqa: F821 + def __init__( + self, + t0: Scalar, + t1: Scalar, + tol: Scalar, + shape: Union[Tuple[int, ...], PyTree[jax.ShapeDtypeStruct]], + key: "jax.random.PRNGKey", + ): + self.t0 = t0 + self.t1 = t1 + self.tol = tol + self.shape = ( + jax.ShapeDtypeStruct(shape, None) if is_tuple_of_ints(shape) else shape + ) + if any( + not jnp.issubdtype(x.dtype, jnp.inexact) + for x in jtu.tree_leaves(self.shape) + ): + raise ValueError( + "VirtualBrownianTree dtypes all have to be floating-point." 
+ ) + self.key = split_by_tree(key, self.shape) + @eqx.filter_jit def evaluate( self, t0: Scalar, t1: Optional[Scalar] = None, left: bool = True - ) -> Array: + ) -> PyTree[Array]: del left if t1 is None: return self._evaluate(t0) else: - return self._evaluate(t1) - self._evaluate(t0) + return jtu.tree_map( + lambda x, y: x - y, + self._evaluate(t1), + self._evaluate(t0), + ) - def _brownian_bridge(self, s, t, u, w_s, w_u, key): + def _evaluate(self, τ: Scalar) -> PyTree[Array]: + map_func = lambda key, shape: self._evaluate_leaf(key, τ, shape) + return jtu.tree_map(map_func, self.key, self.shape) + + def _brownian_bridge(self, s, t, u, w_s, w_u, key, shape, dtype): mean = w_s + (w_u - w_s) * ((t - s) / (u - s)) var = (u - t) * (t - s) / (u - s) std = jnp.sqrt(var) - return mean + std * jrandom.normal(key, self.shape) + return mean + std * jrandom.normal(key, shape, dtype) + + def _evaluate_leaf( + self, + key, + τ: Scalar, + shape: jax.ShapeDtypeStruct, + ) -> Array: + shape, dtype = shape.shape, shape.dtype - def _evaluate(self, τ: Scalar) -> Array: cond = self.t0 < self.t1 - t0 = jnp.where(cond, self.t0, self.t1) - t1 = jnp.where(cond, self.t1, self.t0) + t0 = jnp.where(cond, self.t0, self.t1).astype(dtype) + t1 = jnp.where(cond, self.t1, self.t0).astype(dtype) error_if( τ < t0, "Cannot evaluate VirtualBrownianTree outside of its range [t0, t1]." @@ -90,12 +129,12 @@ def _evaluate(self, τ: Scalar) -> Array: ) # Clip because otherwise the while loop below won't terminate, and the above # errors are only raised after everything has finished executing. - τ = jnp.clip(τ, t0, t1) + τ = jnp.clip(τ, t0, t1).astype(dtype) - key, init_key = jrandom.split(self.key, 2) + key, init_key = jrandom.split(key, 2) thalf = t0 + 0.5 * (t1 - t0) - w_t1 = jrandom.normal(init_key, self.shape) * jnp.sqrt(t1 - t0) - w_thalf = self._brownian_bridge(t0, thalf, t1, 0, w_t1, key) + w_t1 = jrandom.normal(init_key, shape, dtype) * jnp.sqrt(t1 - t0) + w_thalf = self._brownian_bridge(t0, thalf, t1, 0, w_t1, key, shape, dtype) init_state = _State( s=t0, t=thalf, @@ -124,7 +163,7 @@ def _body_fun(_state): _w_u = jnp.where(_cond, _state.w_u, _state.w_t) _key = jnp.where(_cond, _key1, _key2) _t = _s + 0.5 * (_u - _s) - _w_t = self._brownian_bridge(_s, _t, _u, _w_s, _w_u, _key) + _w_t = self._brownian_bridge(_s, _t, _u, _w_s, _w_u, _key, shape, dtype) return _State(s=_s, t=_t, u=_u, w_s=_w_s, w_t=_w_t, w_u=_w_u, key=_key) final_state = lax.while_loop(_cond_fun, _body_fun, init_state) @@ -162,7 +201,11 @@ def _body_fun(_state): - `t0`: The start of the interval the Brownian motion is defined over. - `t1`: The start of the interval the Brownian motion is defined over. - `tol`: The discretisation that `[t0, t1]` is discretised to. -- `shape`: What shape each individual Brownian sample should be. +- `shape`: Should be a PyTree of `jax.ShapeDtypeStruct`s, representing the shape, +dtype, and PyTree structure of the output. For simplicity, `shape` can also just +be a tuple of integers, describing the shape of a single JAX array. In that case +the dtype is chosen to be `float64` if `JAX_ENABLE_X64=True` and `float32` +otherwise. - `key`: A random key. !!! 
info diff --git a/diffrax/misc/__init__.py b/diffrax/misc/__init__.py index 74e7dbc3..d72e095b 100644 --- a/diffrax/misc/__init__.py +++ b/diffrax/misc/__init__.py @@ -11,9 +11,11 @@ ContainerMeta, fill_forward, force_bitcast_convert_type, + is_tuple_of_ints, left_broadcast_to, linear_rescale, rms_norm, + split_by_tree, ) from .nextafter import nextafter, prevbefore from .omega import ω diff --git a/diffrax/misc/misc.py b/diffrax/misc/misc.py index 875e731a..379cdc3d 100644 --- a/diffrax/misc/misc.py +++ b/diffrax/misc/misc.py @@ -1,9 +1,10 @@ -from typing import Optional, Tuple +from typing import Callable, Optional, Tuple import jax import jax.flatten_util as fu import jax.lax as lax import jax.numpy as jnp +import jax.tree_util as jtu from ..custom_types import Array, PyTree, Scalar @@ -161,3 +162,15 @@ def left_broadcast_to(arr, shape): indices = tuple(slice(None) if i < arr.ndim else None for i in range(len(shape))) return jnp.broadcast_to(arr[indices], shape) + + +def split_by_tree(key, tree, is_leaf: Optional[Callable[[PyTree], bool]] = None): + """Like jax.random.split but accepts tree as a second argument and produces + a tree of keys with the same structure. + """ + treedef = jtu.tree_structure(tree, is_leaf=is_leaf) + return jtu.tree_unflatten(treedef, jax.random.split(key, treedef.num_leaves)) + + +def is_tuple_of_ints(obj): + return isinstance(obj, tuple) and all(isinstance(x, int) for x in obj) diff --git a/test/conftest.py b/test/conftest.py index 51a9358f..612698af 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -31,7 +31,7 @@ def clear_caches(): process = psutil.Process() if process.memory_info().vms > 4 * 2**30: # >4GB memory usage jax.clear_backends() - for module_name, module in sys.modules.items(): + for module_name, module in sys.modules.copy().items(): if module_name.startswith("jax"): if module_name not in ["jax.interpreters.partial_eval"]: for obj_name in dir(module): diff --git a/test/test_brownian.py b/test/test_brownian.py index 1ee2ec15..f22fece5 100644 --- a/test/test_brownian.py +++ b/test/test_brownian.py @@ -4,6 +4,7 @@ import jax import jax.numpy as jnp import jax.random as jrandom +import jax.tree_util as jtu import pytest import scipy.stats as stats @@ -19,10 +20,64 @@ @pytest.mark.parametrize( "ctr", [diffrax.UnsafeBrownianPath, diffrax.VirtualBrownianTree] ) -def test_shape(ctr, getkey): +def test_shape_and_dtype(ctr, getkey): t0 = 0 t1 = 2 - for shape in ((0,), (1, 0), (2,), (3, 4), (1, 2, 3, 4)): + + shapes = ( + (0,), + ( + 1, + 0, + ), + (2,), + (3, 4), + (1, 2, 3, 4), + { + "a": (1,), + "b": ( + 2, + 3, + ), + }, + ( + ( + 1, + 2, + ), + ( + ( + 3, + 4, + ), + ( + 5, + 6, + ), + ), + ), + ) + + dtypes = ( + None, + None, + jnp.float16, + jnp.float32, + jnp.float64, + {"a": None, "b": jnp.float64}, + (jnp.float16, (jnp.float32, jnp.float64)), + ) + + def is_tuple_of_ints(obj): + return isinstance(obj, tuple) and all(isinstance(x, int) for x in obj) + + for shape, dtype in zip(shapes, dtypes): + # Shape to pass as input + if dtype is not None: + shape = jtu.tree_map( + jax.ShapeDtypeStruct, shape, dtype, is_leaf=is_tuple_of_ints + ) + if ctr is diffrax.UnsafeBrownianPath: path = ctr(shape, getkey()) assert path.t0 is None @@ -34,12 +89,22 @@ def test_shape(ctr, getkey): assert path.t1 == 2 else: assert False + + # Expected output shape + if dtype is None: + shape = jtu.tree_map( + jax.ShapeDtypeStruct, shape, dtype, is_leaf=is_tuple_of_ints + ) + for _t0 in _vals.values(): for _t1 in _vals.values(): t0, _ = _t0 _, t1 = _t1 out = 
path.evaluate(t0, t1) - assert out.shape == shape + out_shape = jtu.tree_map( + lambda leaf: jax.ShapeDtypeStruct(leaf.shape, leaf.dtype), out + ) + assert out_shape == shape @pytest.mark.parametrize( From 316fca42ce85a7151d6876fa8f22fd387a3f84b0 Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Mon, 7 Nov 2022 13:14:55 -0800 Subject: [PATCH 2/5] Doc fix --- diffrax/brownian/path.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/diffrax/brownian/path.py b/diffrax/brownian/path.py index a7f16a4a..eaee73b8 100644 --- a/diffrax/brownian/path.py +++ b/diffrax/brownian/path.py @@ -88,9 +88,8 @@ def _evaluate_leaf(self, t0: Scalar, t1: Scalar, key, shape: jax.ShapeDtypeStruc **Arguments:** - `shape`: Should be a PyTree of `jax.ShapeDtypeStruct`s, representing the shape, -dtype, and PyTree structure of the output. For simplicity, `shape` can also just -be a tuple of integers, describing the shape of a single JAX array. In that case -the dtype is chosen to be `float64` if `JAX_ENABLE_X64=True` and `float32` -otherwise. + dtype, and PyTree structure of the output. For simplicity, `shape` can also just + be a tuple of integers, describing the shape of a single JAX array. In that case + the dtype is chosen to be the default floating-point dtype. - `key`: A random key. """ From 0920550f008f10f2ef841429d4b81c816b08c01a Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Mon, 7 Nov 2022 13:15:15 -0800 Subject: [PATCH 3/5] Doc fix. --- diffrax/brownian/tree.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/diffrax/brownian/tree.py b/diffrax/brownian/tree.py index e2e7b1c8..8507acf2 100644 --- a/diffrax/brownian/tree.py +++ b/diffrax/brownian/tree.py @@ -202,10 +202,9 @@ def _body_fun(_state): - `t1`: The start of the interval the Brownian motion is defined over. - `tol`: The discretisation that `[t0, t1]` is discretised to. - `shape`: Should be a PyTree of `jax.ShapeDtypeStruct`s, representing the shape, -dtype, and PyTree structure of the output. For simplicity, `shape` can also just -be a tuple of integers, describing the shape of a single JAX array. In that case -the dtype is chosen to be `float64` if `JAX_ENABLE_X64=True` and `float32` -otherwise. + dtype, and PyTree structure of the output. For simplicity, `shape` can also just + be a tuple of integers, describing the shape of a single JAX array. In that case + the dtype is chosen to be the default floating-point dtype. - `key`: A random key. !!! info From df79f8e3c122c7d465cd38c5cda7d027ac6043fb Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Mon, 14 Nov 2022 22:50:58 -0800 Subject: [PATCH 4/5] Upgraded to eqx.internal. Performance improvements. 
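This patch removes several utilities from `diffrax.misc` (`ω`, `unvmap_all`/`unvmap_any`/`unvmap_max`, `error_if`/`branched_error_if`, `ContainerMeta`, the nondifferentiable markers) in favour of their `equinox.internal` counterparts. The main behavioural difference is that `eqxi.error_if` is functional: it returns its first argument with the runtime check attached, and that returned value must be used downstream for the check to take effect. A minimal sketch of the migration pattern used throughout this patch (the names `x`, `pred`, `flag` and `y` are placeholders, not taken from the diff):

    import equinox.internal as eqxi
    import jax.numpy as jnp
    from equinox.internal import ω  # previously `from diffrax.misc import ω`

    x = jnp.array([1.0, 2.0])
    pred = jnp.any(x < 0)

    # Old style (removed): `error_if(pred, "x must be non-negative")`, called purely
    # for its side effect. New style: thread the returned value onwards so the
    # check is not dropped by the compiler.
    x = eqxi.error_if(x, pred, "x must be non-negative")

    # `unvmap_any` collapses any vmapped batch dimensions of a boolean predicate.
    flag = eqxi.unvmap_any(pred)

    # ω still provides elementwise arithmetic over pytrees, now imported from Equinox.
    y = ((x**ω + x**ω) / 2).ω
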
--- benchmarks/compile_times.py | 80 ++++++++ benchmarks/small_neural_ode.py | 9 +- diffrax/__init__.py | 2 +- diffrax/adjoint.py | 139 ++++++++++--- diffrax/brownian/path.py | 12 +- diffrax/brownian/tree.py | 15 +- diffrax/global_interpolation.py | 15 +- diffrax/integrate.py | 103 ++++++---- diffrax/local_interpolation.py | 7 +- diffrax/misc/__init__.py | 12 +- diffrax/misc/ad.py | 73 ------- diffrax/misc/bounded_while_loop.py | 5 +- diffrax/misc/errors.py | 36 ---- diffrax/misc/misc.py | 52 ++--- diffrax/misc/nextafter.py | 27 --- diffrax/misc/omega.py | 231 ---------------------- diffrax/misc/unvmap.py | 111 ----------- diffrax/path.py | 4 +- diffrax/solution.py | 8 +- diffrax/solver/dopri8.py | 3 +- diffrax/solver/euler.py | 2 +- diffrax/solver/euler_heun.py | 2 +- diffrax/solver/implicit_euler.py | 2 +- diffrax/solver/leapfrog_midpoint.py | 2 +- diffrax/solver/milstein.py | 2 +- diffrax/solver/reversible_heun.py | 2 +- diffrax/solver/runge_kutta.py | 86 +++++--- diffrax/solver/semi_implicit_euler.py | 2 +- diffrax/solver/tsit5.py | 3 +- diffrax/step_size_controller/adaptive.py | 59 +++--- diffrax/step_size_controller/constant.py | 10 +- diffrax/term.py | 2 +- docs/api/adjoints.md | 2 +- docs/devdocs/omega.md | 49 ----- docs/further_details/faq.md | 3 +- mkdocs.yml | 1 - setup.py | 2 +- test/test_adaptive_stepsize_controller.py | 25 +++ test/test_adjoint.py | 23 +-- test/test_brownian.py | 27 +-- test/test_detest.py | 5 +- test/test_integrate.py | 2 +- test/test_misc.py | 147 +------------- test/test_vmap.py | 3 +- 44 files changed, 478 insertions(+), 929 deletions(-) create mode 100644 benchmarks/compile_times.py delete mode 100644 diffrax/misc/errors.py delete mode 100644 diffrax/misc/nextafter.py delete mode 100644 diffrax/misc/omega.py delete mode 100644 diffrax/misc/unvmap.py delete mode 100644 docs/devdocs/omega.md diff --git a/benchmarks/compile_times.py b/benchmarks/compile_times.py new file mode 100644 index 00000000..f9598c7b --- /dev/null +++ b/benchmarks/compile_times.py @@ -0,0 +1,80 @@ +import functools as ft +import timeit + +import diffrax as dfx +import equinox as eqx +import fire +import jax +import jax.numpy as jnp +import jax.random as jr + + +def _weight(in_, out, key): + return [[w_ij for w_ij in w_i] for w_i in jr.normal(key, (out, in_))] + + +class VectorField(eqx.Module): + weights: list + + def __init__(self, in_, out, width, depth, *, key): + keys = jr.split(key, depth + 1) + self.weights = [_weight(in_, width, keys[0])] + for i in range(1, depth): + self.weights.append(_weight(width, width, keys[i])) + self.weights.append(_weight(width, out, keys[depth])) + + def __call__(self, t, y, args): + # Inefficient computation graph to make a toy example more expensive. 
+ y = [y_i for y_i in y] + for w in self.weights: + y = [sum(w_ij * y_j for w_ij, y_j in zip(w_i, y)) for w_i in w] + return jnp.stack(y) + + +def main(inline: bool, scan_stages: bool, grad: bool, adjoint: str): + if adjoint == "direct": + adjoint = dfx.DirectAdjoint() + elif adjoint == "recursive": + adjoint = dfx.RecursiveCheckpointAdjoint() + elif adjoint == "backsolve": + adjoint = dfx.BacksolveAdjoint() + else: + raise ValueError + if grad: + grad_decorator = jax.grad + else: + grad_decorator = lambda x: x + + vf = VectorField(1, 1, 16, 2, key=jr.PRNGKey(0)) + if not inline: + vf = eqx.internal.noinline(vf) + term = dfx.ODETerm(vf) + solver = dfx.Dopri8(scan_stages=scan_stages) + stepsize_controller = dfx.PIDController(rtol=1e-3, atol=1e-6) + t0 = 0 + t1 = 1 + dt0 = 0.01 + + @jax.jit + @grad_decorator + def solve(y0): + sol = dfx.diffeqsolve( + term, + solver, + t0, + t1, + dt0, + y0, + stepsize_controller=stepsize_controller, + adjoint=adjoint, + max_steps=16**2, + ) + return jnp.sum(sol.ys) + + solve_ = ft.partial(solve, jnp.array([1.0])) + print("Compile+run time", timeit.timeit(solve_, number=1)) + print("Run time", timeit.timeit(solve_, number=1)) + + +if __name__ == "__main__": + fire.Fire(main) diff --git a/benchmarks/small_neural_ode.py b/benchmarks/small_neural_ode.py index 5a9d7e7c..95eb2260 100644 --- a/benchmarks/small_neural_ode.py +++ b/benchmarks/small_neural_ode.py @@ -9,6 +9,7 @@ import jax.nn as jnn import jax.numpy as jnp import jax.random as jrandom +import numpy as np import torch import torchdiffeq @@ -173,10 +174,10 @@ def main(batch_size=64, t1=100, multiple=False, grad=False): with torch.no_grad(): func_jax = neural_ode_diffrax.func.func func_torch = neural_ode_torch.func.func - func_torch[0].weight.copy_(torch.tensor(func_jax.layers[0].weight.to_py())) - func_torch[0].bias.copy_(torch.tensor(func_jax.layers[0].bias.to_py())) - func_torch[2].weight.copy_(torch.tensor(func_jax.layers[1].weight.to_py())) - func_torch[2].bias.copy_(torch.tensor(func_jax.layers[1].bias.to_py())) + func_torch[0].weight.copy_(torch.tensor(np.asarray(func_jax.layers[0].weight))) + func_torch[0].bias.copy_(torch.tensor(np.asarray(func_jax.layers[0].bias))) + func_torch[2].weight.copy_(torch.tensor(np.asarray(func_jax.layers[1].weight))) + func_torch[2].bias.copy_(torch.tensor(np.asarray(func_jax.layers[1].bias))) y0_jax = jrandom.normal(jrandom.PRNGKey(1), (batch_size, 4)) y0_torch = torch.tensor(y0_jax.to_py()) diff --git a/diffrax/__init__.py b/diffrax/__init__.py index 07f2c669..a038518a 100644 --- a/diffrax/__init__.py +++ b/diffrax/__init__.py @@ -87,4 +87,4 @@ ) -__version__ = "0.2.1" +__version__ = "0.2.2" diff --git a/diffrax/adjoint.py b/diffrax/adjoint.py index f4452a81..981447be 100644 --- a/diffrax/adjoint.py +++ b/diffrax/adjoint.py @@ -2,15 +2,67 @@ from typing import Any, Dict import equinox as eqx +import equinox.internal as eqxi import jax.lax as lax import jax.numpy as jnp import jax.tree_util as jtu +from equinox.internal import ω -from .misc import implicit_jvp, nondifferentiable_output, ω +from .misc import implicit_jvp from .saveat import SaveAt from .term import AbstractTerm, AdjointTerm +def _is_none(x): + return x is None + + +def _no_transpose_final_state(final_state): + y = eqxi.nondifferentiable_backward(final_state.y, name="y") + tprev = eqxi.nondifferentiable_backward(final_state.tprev, name="tprev") + tnext = eqxi.nondifferentiable_backward(final_state.tnext, name="tnext") + solver_state = eqxi.nondifferentiable_backward( + final_state.solver_state, 
name="solver_state" + ) + controller_state = eqxi.nondifferentiable_backward( + final_state.controller_state, name="controller_state" + ) + ts = eqxi.nondifferentiable_backward(final_state.ts, name="ts") + ys = final_state.ys + dense_ts = eqxi.nondifferentiable_backward(final_state.dense_ts, name="dense_ts") + dense_infos = eqxi.nondifferentiable_backward( + final_state.dense_infos, name="dense_infos" + ) + final_state = eqxi.nondifferentiable_backward(final_state) # no more specific name + final_state = eqx.tree_at( + lambda s: ( + s.y, + s.tprev, + s.tnext, + s.solver_state, + s.controller_state, + s.ts, + s.ys, + s.dense_ts, + s.dense_infos, + ), + final_state, + ( + y, + tprev, + tnext, + solver_state, + controller_state, + ts, + ys, + dense_ts, + dense_infos, + ), + is_leaf=_is_none, + ) + return final_state + + class AbstractAdjoint(eqx.Module): """Abstract base class for all adjoint methods.""" @@ -30,6 +82,8 @@ def loop( max_steps, throw, init_state, + passed_solver_state, + passed_controller_state, ): """Runs the main solve loop. Subclasses can override this to provide custom backpropagation behaviour; see for example the implementation of @@ -69,27 +123,26 @@ class RecursiveCheckpointAdjoint(AbstractAdjoint): For most problems this is the preferred technique for backpropagating through a differential equation. - A binomial checkpointing scheme is used so that memory usage is low. + In addition a binomial checkpointing scheme is used so that memory usage is low. + (This checkpointing can increase compile time a bit, though.) """ - def loop(self, *, throw, **kwargs): - del throw + def loop(self, *, throw, passed_solver_state, passed_controller_state, **kwargs): + del throw, passed_solver_state, passed_controller_state return self._loop_fn(**kwargs, is_bounded=True) class NoAdjoint(AbstractAdjoint): """Disable backpropagation through [`diffrax.diffeqsolve`][]. - Forward-mode autodifferentiation (`jax.jvp`) will continue to work as normal. - If you do not need to differentiate the results of [`diffrax.diffeqsolve`][] then this may sometimes improve the speed at which the differential equation is solved. """ - def loop(self, *, throw, **kwargs): - del throw + def loop(self, *, throw, passed_solver_state, passed_controller_state, **kwargs): + del throw, passed_solver_state, passed_controller_state final_state, aux_stats = self._loop_fn(**kwargs, is_bounded=False) - final_state = jtu.tree_map(nondifferentiable_output, final_state) + final_state = eqxi.nondifferentiable_backward(final_state) return final_state, aux_stats @@ -135,7 +188,19 @@ class ImplicitAdjoint(AbstractAdjoint): via the implicit function theorem. """ # noqa: E501 - def loop(self, *, args, terms, solver, saveat, throw, init_state, **kwargs): + def loop( + self, + *, + args, + terms, + solver, + saveat, + throw, + init_state, + passed_solver_state, + passed_controller_state, + **kwargs, + ): del throw # `is` check because this may return a Tracer from SaveAt(ts=) @@ -144,21 +209,30 @@ def loop(self, *, args, terms, solver, saveat, throw, init_state, **kwargs): "Can only use `adjoint=ImplicitAdjoint()` with `SaveAt(t1=True)`." 
) - init_state = eqx.tree_at( - lambda s: (s.y, s.solver_state, s.controller_state), - init_state, - replace_fn=lax.stop_gradient, - ) + if not passed_solver_state: + init_state = eqx.tree_at( + lambda s: s.solver_state, + init_state, + replace_fn=lax.stop_gradient, + is_leaf=_is_none, + ) + if not passed_controller_state: + init_state = eqx.tree_at( + lambda s: s.controller_state, + init_state, + replace_fn=lax.stop_gradient, + is_leaf=_is_none, + ) + closure = (self, kwargs, solver, saveat, init_state) ys, residual = implicit_jvp(_solve, _vf, (args, terms), closure) final_state_no_ys, aux_stats = residual - return ( - eqx.tree_at( - lambda s: s.ys, final_state_no_ys, ys, is_leaf=lambda x: x is None - ), - aux_stats, + final_state = eqx.tree_at( + lambda s: s.ys, final_state_no_ys, ys, is_leaf=_is_none ) + final_state = _no_transpose_final_state(final_state) + return final_state, aux_stats # Compute derivatives with respect to the first argument: @@ -174,7 +248,7 @@ def _loop_backsolve(y__args__terms, *, self, throw, init_state, **kwargs): ) del y return self._loop_fn( - args=args, terms=terms, init_state=init_state, **kwargs, is_bounded=False + args=args, terms=terms, init_state=init_state, is_bounded=False, **kwargs ) @@ -398,7 +472,18 @@ def __init__(self, **kwargs): ) self.kwargs = kwargs - def loop(self, *, args, terms, saveat, init_state, **kwargs): + def loop( + self, + *, + args, + terms, + saveat, + init_state, + passed_solver_state, + passed_controller_state, + **kwargs, + ): + del passed_solver_state, passed_controller_state if saveat.steps or saveat.dense: raise NotImplementedError( "Cannot use `adjoint=BacksolveAdjoint()` with " @@ -414,13 +499,5 @@ def loop(self, *, args, terms, saveat, init_state, **kwargs): final_state, aux_stats = _loop_backsolve( (y, args, terms), self=self, saveat=saveat, init_state=init_state, **kwargs ) - - # We only allow backpropagation through `ys`; in particular not through - # `solver_state` etc. 
- ys = final_state.ys - final_state = jtu.tree_map(nondifferentiable_output, final_state) - final_state = eqx.tree_at( - lambda s: jtu.tree_leaves(s.ys), final_state, jtu.tree_leaves(ys) - ) - + final_state = _no_transpose_final_state(final_state) return final_state, aux_stats diff --git a/diffrax/brownian/path.py b/diffrax/brownian/path.py index eaee73b8..e9c0136c 100644 --- a/diffrax/brownian/path.py +++ b/diffrax/brownian/path.py @@ -1,18 +1,14 @@ from typing import Tuple, Union import equinox as eqx +import equinox.internal as eqxi import jax import jax.numpy as jnp import jax.random as jrandom import jax.tree_util as jtu from ..custom_types import Array, PyTree, Scalar -from ..misc import ( - force_bitcast_convert_type, - is_tuple_of_ints, - nondifferentiable_input, - split_by_tree, -) +from ..misc import force_bitcast_convert_type, is_tuple_of_ints, split_by_tree from .base import AbstractBrownianPath @@ -67,8 +63,8 @@ def t1(self): @eqx.filter_jit def evaluate(self, t0: Scalar, t1: Scalar, left: bool = True) -> PyTree[Array]: del left - nondifferentiable_input(t0, "t0") - nondifferentiable_input(t1, "t1") + t0 = eqxi.nondifferentiable(t0, name="t0") + t1 = eqxi.nondifferentiable(t1, name="t0") t0_ = force_bitcast_convert_type(t0, jnp.int32) t1_ = force_bitcast_convert_type(t1, jnp.int32) key = jrandom.fold_in(self.key, t0_) diff --git a/diffrax/brownian/tree.py b/diffrax/brownian/tree.py index 8507acf2..38cb96e2 100644 --- a/diffrax/brownian/tree.py +++ b/diffrax/brownian/tree.py @@ -2,6 +2,7 @@ from typing import Optional, Tuple, Union import equinox as eqx +import equinox.internal as eqxi import jax import jax.lax as lax import jax.numpy as jnp @@ -9,7 +10,7 @@ import jax.tree_util as jtu from ..custom_types import Array, PyTree, Scalar -from ..misc import error_if, is_tuple_of_ints, split_by_tree +from ..misc import is_tuple_of_ints, split_by_tree from .base import AbstractBrownianPath @@ -121,11 +122,15 @@ def _evaluate_leaf( t0 = jnp.where(cond, self.t0, self.t1).astype(dtype) t1 = jnp.where(cond, self.t1, self.t0).astype(dtype) - error_if( - τ < t0, "Cannot evaluate VirtualBrownianTree outside of its range [t0, t1]." + t0 = eqxi.error_if( + t0, + τ < t0, + "Cannot evaluate VirtualBrownianTree outside of its range [t0, t1].", ) - error_if( - τ > t1, "Cannot evaluate VirtualBrownianTree outside of its range [t0, t1]." + t1 = eqxi.error_if( + t1, + τ > t1, + "Cannot evaluate VirtualBrownianTree outside of its range [t0, t1].", ) # Clip because otherwise the while loop below won't terminate, and the above # errors are only raised after everything has finished executing. diff --git a/diffrax/global_interpolation.py b/diffrax/global_interpolation.py index 83859a9e..a20271a4 100644 --- a/diffrax/global_interpolation.py +++ b/diffrax/global_interpolation.py @@ -2,14 +2,16 @@ from typing import Optional, Tuple, Type import equinox as eqx +import equinox.internal as eqxi import jax import jax.lax as lax import jax.numpy as jnp import jax.tree_util as jtu +from equinox.internal import ω from .custom_types import Array, DenseInfos, Int, PyTree, Scalar from .local_interpolation import AbstractLocalInterpolation -from .misc import error_if, fill_forward, left_broadcast_to, ω +from .misc import fill_forward, left_broadcast_to from .path import AbstractPath @@ -347,7 +349,10 @@ def _check_ts(ts: Array["times"]) -> None: # noqa: F821 if ts.shape[0] < 2: raise ValueError(f"`ts` must be of length at least 2; got {ts.shape[0]}") # Also catches any NaN times. 
- error_if(ts[:-1] >= ts[1:], "`ts` must be monotonically strictly increasing.") + ts = eqxi.error_if( + ts, ts[:-1] >= ts[1:], "`ts` must be monotonically strictly increasing." + ) + return ts def _interpolation_reverse( @@ -448,7 +453,7 @@ def linear_interpolation( As `ys`, but with `NaN` values filled in. """ - _check_ts(ts) + ts = _check_ts(ts) fn = ft.partial(_linear_interpolation, fill_forward_nans_at_end, ts) if replace_nans_at_start is None: return jtu.tree_map(fn, ys) @@ -538,7 +543,7 @@ def rectilinear_interpolation( are something we are free to pick. """ - _check_ts(ts) + ts = _check_ts(ts) if replace_nans_at_start is None: fn = ft.partial(_rectilinear_interpolation, ts, None) out = jtu.tree_map(fn, ys) @@ -710,7 +715,7 @@ def backward_hermite_coefficients( to `ts[1]`, and `ts[1]` to `ts[2]` etc. """ - _check_ts(ts) + ts = _check_ts(ts) fn = ft.partial(_backward_hermite_coefficients, fill_forward_nans_at_end, ts) if deriv0 is None: if replace_nans_at_start is None: diff --git a/diffrax/integrate.py b/diffrax/integrate.py index 12f7a20e..1bf48e00 100644 --- a/diffrax/integrate.py +++ b/diffrax/integrate.py @@ -3,6 +3,7 @@ from typing import Optional import equinox as eqx +import equinox.internal as eqxi import jax import jax.lax as lax import jax.numpy as jnp @@ -18,14 +19,7 @@ from .event import AbstractDiscreteTerminatingEvent from .global_interpolation import DenseInterpolation from .heuristics import is_sde, is_unsafe_sde -from .misc import ( - bounded_while_loop, - branched_error_if, - error_if, - HadInplaceUpdate, - unvmap_all, - unvmap_max, -) +from .misc import bounded_while_loop, HadInplaceUpdate from .saveat import SaveAt from .solution import is_okay, is_successful, RESULTS, Solution from .solver import AbstractItoSolver, AbstractSolver, AbstractStratonovichSolver, Euler @@ -33,6 +27,7 @@ AbstractAdaptiveStepSizeController, AbstractStepSizeController, ConstantStepSize, + PIDController, StepTo, ) from .term import AbstractTerm, WrapTerm @@ -120,6 +115,15 @@ def loop( dense_ts = dense_ts.at[0].set(t0) init_state = eqx.tree_at(lambda s: s.dense_ts, init_state, dense_ts) + # Privileged optimisation for the common case of no jumps. We can reduce + # solver compile time with this. + # TODO: somehow make this a non-priviliged optimisation, i.e. detect when + # we can make jumps or not. + cannot_make_jump = isinstance(stepsize_controller, (ConstantStepSize, StepTo)) or ( + isinstance(stepsize_controller, PIDController) + and stepsize_controller.jump_ts is None + ) + def cond_fun(state): return (state.tprev < t1) & is_successful(state.result) @@ -137,7 +141,7 @@ def body_fun(state, inplace): state.y, args, state.solver_state, - state.made_jump, + False if cannot_make_jump else state.made_jump, ) # e.g. if someone has a sqrt(y) in the vector field, and dt0 is so large that @@ -164,6 +168,16 @@ def body_fun(state, inplace): state.controller_state, ) assert jnp.result_type(keep_step) is jnp.dtype(bool) + if cannot_make_jump: + # Should hopefully get DCE'd out. + made_jump = eqxi.error_if( + made_jump, + made_jump, + ( + "Internal error in Diffrax: made unexpected jump. Please report an " + "issue at https://github.com/patrick-kidger/diffrax/issues" + ), + ) # # Do some book-keeping. 
@@ -418,7 +432,7 @@ def maybe_inplace(i, x, u): with jax.ensure_compile_time_eval(): def _is_finite(_t): - all_finite = unvmap_all(jnp.isfinite(_t)) + all_finite = eqxi.unvmap_all(jnp.isfinite(_t)) return not isinstance(all_finite, jax.core.Tracer) and all_finite if _is_finite(t0) and _is_finite(t1) and _is_finite(dt0): @@ -434,7 +448,7 @@ def _body_fun(_state): compiled_num_steps, _ = lax.while_loop( _cond_fun, _body_fun, (0, t0) ) - compiled_num_steps = unvmap_max(compiled_num_steps) + compiled_num_steps = eqxi.unvmap_max(compiled_num_steps) else: if stepsize_controller.compile_steps is None: compiled_num_steps = None @@ -556,9 +570,9 @@ def diffeqsolve( understand these. All of these are keyword-only arguments. - `adjoint`: How to backpropagate (and compute forward-mode autoderivatives) of - `diffeqsolve`. Defaults to discretise-then-optimise with recursive - checkpointing, which is usually the best option for most problems. See the page - on [Adjoints](./adjoints.md) for more information. + `diffeqsolve`. Defaults to discretise-then-optimise, which is usually the best + option for most problems. See the page on [Adjoints](./adjoints.md) for more + information. - `discrete_terminating_event`: A discrete event at which to terminate the solve early. See the page on [Events](./events.md) for more information. @@ -567,8 +581,9 @@ def diffeqsolve( unconditionally. Can also be set to `None` to allow an arbitrary number of steps, although this - is incompatible with `adjoint=RecursiveCheckpointAdjoint()` (the default) and - is incompatible with `saveat=SaveAt(steps=True)` or `saveat=SaveAt(dense=True)`. + is incompatible with `saveat=SaveAt(steps=True)` or `saveat=SaveAt(dense=True)`, + and can only be backpropagated through if using `adjoint=BacksolveAdjoint()` or + `adjoint=ImplicitAdjoint()`. Note that (a) compile times; and (b) backpropagation run times; will increase as `max_steps` increases. (Specifically, each time `max_steps` passes a power @@ -577,8 +592,9 @@ def diffeqsolve( - `throw`: Whether to raise an exception if the integration fails for any reason. - If `True` then an integration failure will either raise a `ValueError` (when - not using `jax.jit`) or print a warning message (when using `jax.jit`). + If `True` then an integration failure will raise an error. Note that the errors + are only reliably raised on CPUs. If on GPUs then the error may only be + printed to stderr, whilst on TPUs then the behaviour is undefined. If `False` then the returned solution object will have a `result` field indicating whether any failures occurred. @@ -634,7 +650,9 @@ def diffeqsolve( f"t0 with value {t0} and type {type(t0)}, " f"dt0 with value {dt0} and type {type(dt0)}" ) - error_if((t1 - t0) * dt0 < 0, msg) + with jax.ensure_compile_time_eval(): + pred = (t1 - t0) * dt0 < 0 + dt0 = eqxi.error_if(dt0, pred, msg) # Error checking term_leaves, term_structure = jtu.tree_flatten( @@ -677,7 +695,7 @@ def diffeqsolve( if isinstance(solver, Euler): raise ValueError( "An SDE should not be solved with adaptive step sizes with Euler's " - "method; it will not converge to the correct solution." + "method, as it may not converge to the correct solution." ) if is_unsafe_sde(terms): if isinstance(stepsize_controller, AbstractAdaptiveStepSizeController): @@ -691,8 +709,6 @@ def diffeqsolve( # Allow setting e.g. t0 as an int with dt0 as a float. (We need consistent # types for JAX to be happy with the bounded_while_loop below.) 
- # Use compile-time-eval to avoid turning timelikes into spurious tracers, which - # inhibit optimisation via compile-time number-of-step inference. with jax.ensure_compile_time_eval(): timelikes = (jnp.array(0.0), t0, t1, dt0, saveat.ts) timelikes = [x for x in timelikes if x is not None] @@ -713,7 +729,6 @@ def _promote(yi): del timelikes, dtype # Normalises time: if t0 > t1 then flip things around. - # Once again use compile-time-eval to keep the timelikes non-tracer if possible. with jax.ensure_compile_time_eval(): direction = jnp.where(t0 < t1, 1, -1) t0 = t0 * direction @@ -736,22 +751,28 @@ def _promote(yi): # Error checking if saveat.ts is not None: - error_if( + saveat_ts = eqxi.error_if( + saveat.ts, saveat.ts[1:] < saveat.ts[:-1], "saveat.ts must be increasing or decreasing.", ) - error_if( - (saveat.ts > t1) | (saveat.ts < t0), "saveat.ts must lie between t0 and t1." + saveat_ts = eqxi.error_if( + saveat_ts, + (saveat.ts > t1) | (saveat.ts < t0), + "saveat.ts must lie between t0 and t1.", ) + saveat = eqx.tree_at(lambda s: s.ts, saveat, saveat_ts) # Initialise states tprev = t0 error_order = solver.error_order(terms) if controller_state is None: + passed_controller_state = False (tnext, controller_state) = stepsize_controller.init( terms, t0, t1, y0, dt0, args, solver.func, error_order ) else: + passed_controller_state = True if dt0 is None: (tnext, _) = stepsize_controller.init( terms, t0, t1, y0, dt0, args, solver.func, error_order @@ -760,7 +781,10 @@ def _promote(yi): tnext = t0 + dt0 tnext = jnp.minimum(tnext, t1) if solver_state is None: + passed_solver_state = False solver_state = solver.init(terms, t0, tnext, y0, args) + else: + passed_solver_state = True # Allocate memory to store output. out_size = 0 @@ -789,7 +813,7 @@ def _promote(yi): ys = jtu.tree_map(lambda y: jnp.full((out_size,) + jnp.shape(y), jnp.inf), y0) result = jnp.array(RESULTS.successful) if saveat.dense: - error_if(t0 == t1, "Cannot save dense output if t0 == t1") + t0 = eqxi.error_if(t0, t0 == t1, "Cannot save dense output if t0 == t1") if max_steps is None: raise ValueError( "`max_steps=None` is incompatible with `saveat.dense=True`" @@ -842,8 +866,10 @@ def _promote(yi): t1=t1, dt0=dt0, max_steps=max_steps, - throw=throw, init_state=init_state, + throw=throw, + passed_solver_state=passed_solver_state, + passed_controller_state=passed_controller_state, ) # @@ -887,23 +913,15 @@ def _promote(yi): t1 = t1 * direction # Store metadata - compiled_num_steps = aux_stats["compiled_num_steps"] stats = { "num_steps": final_state.num_steps, "num_accepted_steps": final_state.num_accepted_steps, "num_rejected_steps": final_state.num_rejected_steps, "max_steps": max_steps, - "compiled_num_steps": compiled_num_steps, + "compiled_num_steps": aux_stats["compiled_num_steps"], } result = final_state.result - error_index = unvmap_max(result) - branched_error_if( - throw & jnp.invert(is_okay(result)), - error_index, - RESULTS.reverse_lookup, - ) - - return Solution( + sol = Solution( t0=t0, t1=t1, ts=ts, @@ -915,3 +933,12 @@ def _promote(yi): controller_state=controller_state, made_jump=made_jump, ) + + error_index = eqxi.unvmap_max(result) + sol = eqxi.branched_error_if( + sol, + throw & jnp.invert(is_okay(result)), + error_index, + RESULTS.reverse_lookup, + ) + return sol diff --git a/diffrax/local_interpolation.py b/diffrax/local_interpolation.py index b9046ed9..bf41f915 100644 --- a/diffrax/local_interpolation.py +++ b/diffrax/local_interpolation.py @@ -5,15 +5,16 @@ import jax.numpy as jnp import jax.tree_util as 
jtu import numpy as np +from equinox.internal import ω from .custom_types import Array, PyTree, Scalar -from .misc import linear_rescale, ω +from .misc import linear_rescale from .path import AbstractPath class AbstractLocalInterpolation(AbstractPath): - t0: Scalar = field(init=True) - t1: Scalar = field(init=True) # override init=False on AbstractPath + t0: Scalar = field(init=True, repr=True) + t1: Scalar = field(init=True, repr=True) # override AbstractPath class LocalLinearInterpolation(AbstractLocalInterpolation): diff --git a/diffrax/misc/__init__.py b/diffrax/misc/__init__.py index d72e095b..4b35bc2a 100644 --- a/diffrax/misc/__init__.py +++ b/diffrax/misc/__init__.py @@ -1,14 +1,7 @@ -from .ad import ( - fixed_custom_jvp, - implicit_jvp, - nondifferentiable_input, - nondifferentiable_output, -) +from .ad import implicit_jvp from .bounded_while_loop import bounded_while_loop, HadInplaceUpdate -from .errors import branched_error_if, error_if from .misc import ( adjoint_rms_seminorm, - ContainerMeta, fill_forward, force_bitcast_convert_type, is_tuple_of_ints, @@ -17,7 +10,4 @@ rms_norm, split_by_tree, ) -from .nextafter import nextafter, prevbefore -from .omega import ω from .sde_kl_divergence import sde_kl_divergence -from .unvmap import unvmap_all, unvmap_any, unvmap_max diff --git a/diffrax/misc/ad.py b/diffrax/misc/ad.py index 721f038c..7ceff0d9 100644 --- a/diffrax/misc/ad.py +++ b/diffrax/misc/ad.py @@ -1,85 +1,12 @@ import functools as ft -from typing import Any import equinox as eqx import jax import jax.flatten_util as fu -import jax.interpreters.ad as ad -import jax.interpreters.batching as batching -import jax.interpreters.mlir as mlir -import jax.interpreters.xla as xla import jax.lax as lax import jax.numpy as jnp import jax.tree_util as jtu -from ..custom_types import PyTree - - -# TODO: this will sometimes return False on a perturbed array, see JAX issue #9567. -# Correspondingly it should *not be used* until that is fixed. -# (The only use is in nondifferentiable_input, below, which will simply not raise -# errors quite as frequently as it should do -- not too bad.) -def is_perturbed(x: Any) -> bool: - if isinstance(x, jax.ad.JVPTracer): - return True - elif isinstance(x, jax.core.Tracer): - return any(is_perturbed(attr) for name, attr in x._contents()) - else: - return False - - -def nondifferentiable_input(x: PyTree, name: str) -> None: - if any(is_perturbed(xi) for xi in jtu.tree_leaves(x)): - raise ValueError(f"Cannot differentiate {name}.") - - -_nondifferentiable_output_p = jax.core.Primitive("nondifferentiable_output") - - -def _nondifferentiable_output_batch(x, batch_axes): - (x,) = x - (batch_axes,) = batch_axes - return nondifferentiable_output(x), batch_axes - - -def _nondifferentiable_output_jvp(primals, tangents): - (primals,) = primals - (tangents,) = tangents - return nondifferentiable_output(primals), nondifferentiable_output(tangents) - - -def _nondifferentiable_output_transpose(cts_in, _): - if isinstance(cts_in, ad.Zero): - return ad.Zero # the class, not an instance - else: - raise RuntimeError( - "Reverse-mode autodifferentiation is disabled for this operation." 
- ) - - -_nondifferentiable_output_p.def_impl(lambda x: x) -_nondifferentiable_output_p.def_abstract_eval(lambda x: x) -batching.primitive_batchers[ - _nondifferentiable_output_p -] = _nondifferentiable_output_batch -if hasattr(xla, "lower_fun"): - xla.register_translation( - _nondifferentiable_output_p, - xla.lower_fun(lambda x: x, multiple_results=False, new_style=True), - ) -mlir.register_lowering( - _nondifferentiable_output_p, - mlir.lower_fun(lambda x: x, multiple_results=False), -) -ad.primitive_jvps[_nondifferentiable_output_p] = _nondifferentiable_output_jvp -ad.primitive_transposes[ - _nondifferentiable_output_p -] = _nondifferentiable_output_transpose - - -def nondifferentiable_output(x: PyTree) -> PyTree: - return _nondifferentiable_output_p.bind(x) - class fixed_custom_jvp: """As jax.custom_jvp but works around JAX issue #9374.""" diff --git a/diffrax/misc/bounded_while_loop.py b/diffrax/misc/bounded_while_loop.py index d983625a..ae3b6c89 100644 --- a/diffrax/misc/bounded_while_loop.py +++ b/diffrax/misc/bounded_while_loop.py @@ -1,13 +1,13 @@ import math import equinox as eqx +import equinox.internal as eqxi import jax import jax.lax as lax import jax.numpy as jnp import jax.tree_util as jtu from ..custom_types import Array -from .unvmap import unvmap_any def bounded_while_loop(cond_fun, body_fun, init_val, max_steps, base=16): @@ -114,6 +114,7 @@ def _make_update(_new_val): def _body_fun(_val): inplace = lambda x: x inplace.pred = True + inplace.merge = lambda x: x _new_val = body_fun(_val, inplace) return jtu.tree_map( _make_update, @@ -230,7 +231,7 @@ def _call(_data): def _scan_fn(_data, _): _pred, _, _ = _data - _unvmap_pred = unvmap_any(_pred) + _unvmap_pred = eqxi.unvmap_any(_pred) return lax.cond(_unvmap_pred, _call, lambda x: x, _data), None # Don't put checkpointing on the lowest level diff --git a/diffrax/misc/errors.py b/diffrax/misc/errors.py deleted file mode 100644 index 8594d5d0..00000000 --- a/diffrax/misc/errors.py +++ /dev/null @@ -1,36 +0,0 @@ -from typing import Sequence, Union - -import jax.experimental.host_callback as hcb -import jax.lax as lax - -from ..custom_types import Array, Int -from .unvmap import unvmap_any - - -def error_if( - pred: Union[bool, Array[..., bool]], - msg: str, -) -> bool: - """For use as part of validating inputs. - Works even under JIT. 
- - Example: - @jax.jit - def f(x): - error_if(x < 0, "x must be >= 0") - - f(jax.numpy.array(-1)) - """ - branched_error_if(pred, 0, [msg]) - - -def branched_error_if( - pred: Union[bool, Array[..., bool]], - index: Int, - msgs: Sequence[str], -) -> bool: - def raises(_index): - raise RuntimeError(msgs[_index.item()]) - - pred = unvmap_any(pred) - lax.cond(pred, lambda: hcb.call(raises, index), lambda: None) diff --git a/diffrax/misc/misc.py b/diffrax/misc/misc.py index 379cdc3d..6ae6797e 100644 --- a/diffrax/misc/misc.py +++ b/diffrax/misc/misc.py @@ -27,32 +27,6 @@ def force_bitcast_convert_type(val, new_type): return lax.bitcast_convert_type(val, new_type) -class ContainerMeta(type): - def __new__(cls, name, bases, dict): - assert "reverse_lookup" not in dict - _dict = {} - reverse_lookup = [] - i = 0 - for key, value in dict.items(): - if key.startswith("__") and key.endswith("__"): - _dict[key] = value - else: - _dict[key] = i - reverse_lookup.append(value) - i += 1 - _dict["reverse_lookup"] = reverse_lookup - return super().__new__(cls, name, bases, _dict) - - def __instancecheck__(cls, instance): - return isinstance(instance, int) or super().__instancecheck__(instance) - - def __getitem__(cls, item): - return cls.reverse_lookup[item] - - def __len__(cls): - return len(cls.reverse_lookup) - - def _fill_forward( last_observed_yi: Array["channels":...], yi: Array["channels":...] # noqa: F821 ) -> Tuple[Array["channels":...], Array["channels":...]]: # noqa: F821 @@ -113,12 +87,26 @@ def rms_norm(x: PyTree) -> Scalar: x, _ = fu.ravel_pytree(x) if x.size == 0: return 0 - sqnorm = jnp.mean(x**2) - cond = sqnorm == 0 - # Double-where trick to avoid NaN gradients. - # See JAX issues #5039 and #1052. - _sqnorm = jnp.where(cond, 1.0, sqnorm) - return jnp.where(cond, 0.0, jnp.sqrt(_sqnorm)) + return _rms_norm(x) + + +@jax.custom_jvp +def _rms_norm(x): + x_sq = jnp.real(x * jnp.conj(x)) + return jnp.sqrt(jnp.mean(x_sq)) + + +@_rms_norm.defjvp +def _rms_norm_jvp(x, tx): + (x,) = x + (tx,) = tx + out = _rms_norm(x) + # Get zero gradient, rather than NaN gradient, in these cases + pred = (out == 0) | jnp.isinf(out) + numerator = jnp.where(pred, 0, x) + denominator = jnp.where(pred, 1, out * x.size) + t_out = jnp.dot(numerator / denominator, tx) + return out, t_out def adjoint_rms_seminorm(x: Tuple[PyTree, PyTree, PyTree, PyTree]) -> Scalar: diff --git a/diffrax/misc/nextafter.py b/diffrax/misc/nextafter.py deleted file mode 100644 index a157f3e6..00000000 --- a/diffrax/misc/nextafter.py +++ /dev/null @@ -1,27 +0,0 @@ -import jax -import jax.numpy as jnp - -from ..custom_types import Array - - -@jax.custom_jvp -def nextafter(x: Array) -> Array: - y = jnp.nextafter(x, jnp.inf) - # Flush denormal to normal. - # Our use for these is to handle jumps in the vector field. Typically that means - # there will be an "if x > cond" condition somewhere. However JAX uses DAZ - # (denormals-are-zero), which will cause this check to fail near zero: - # `jnp.nextafter(0, jnp.inf) > 0` gives `False`. 
- return jnp.where(x == 0, jnp.finfo(x.dtype).tiny, y) - - -nextafter.defjvps(lambda x_dot, _, __: x_dot) - - -@jax.custom_jvp -def prevbefore(x: Array) -> Array: - y = jnp.nextafter(x, jnp.NINF) - return jnp.where(x == 0, -jnp.finfo(x.dtype).tiny, y) - - -prevbefore.defjvps(lambda x_dot, _, __: x_dot) diff --git a/diffrax/misc/omega.py b/diffrax/misc/omega.py deleted file mode 100644 index dc24917d..00000000 --- a/diffrax/misc/omega.py +++ /dev/null @@ -1,231 +0,0 @@ -import operator -from typing import Optional - -import jax.numpy as jnp -import jax.tree_util as jtu - - -class _Metaω(type): - def __rpow__(cls, value): - return cls(value) - - -class ω(metaclass=_Metaω): - """Provides friendlier syntax for mapping with `jax.tree_util.tree_map`. - - !!! example - - ```python - (ω(a) + ω(b)).ω == jax.tree_util.tree_map(operator.add, a, b) - ``` - - !!! tip - - To minimise the number of brackets, the following `__rpow__` syntax can be - used: - - ```python - (a**ω + b**ω).ω == jax.tree_util.tree_map(operator.add, a, b) - ``` - - This is entirely equivalent to the above. - """ - - def __init__(self, value, is_leaf=None): - """ - **Arguments:** - - - `value`: The PyTree to wrap. - - `is_leaf`: An optional value for the `is_leaf` argument to - `jax.tree_util.tree_map`. - - !!! note - - The `is_leaf` argument cannot be set when using the `__rpow__` syntax for - initialisation. - """ - self.ω = value - self.is_leaf = is_leaf - - def __getitem__(self, item): - return ω( - jtu.tree_map(lambda x: x[item], self.ω, is_leaf=self.is_leaf), - is_leaf=self.is_leaf, - ) - - def call(self, fn): - return ω( - jtu.tree_map(fn, self.ω, is_leaf=self.is_leaf), - is_leaf=self.is_leaf, - ) - - @property - def at(self): - return _ωUpdateHelper(self.ω, self.is_leaf) - - -def _equal_code(fn1: Optional[callable], fn2: Optional[callable]): - """Checks whether fn1 and fn2 both have the same code. - - It's essentially impossible to see if two functions are equivalent, so this won't, - and isn't intended, to catch every possible difference between fn1 and fn2. But it - should at least catch the common case that `is_leaf` is specified for one input and - not specified for the other. - """ - sentinel1 = object() - sentinel2 = object() - code1 = getattr(getattr(fn1, "__code__", sentinel1), "co_code", sentinel2) - code2 = getattr(getattr(fn2, "__code__", sentinel1), "co_code", sentinel2) - return type(code1) == type(code2) and code1 == code2 - - -def _set_binary(base, name: str, op: callable) -> callable: - def fn(self, other): - if isinstance(other, ω): - if jtu.tree_structure(self.ω) != jtu.tree_structure(other.ω): - raise ValueError("PyTree structures must match.") - if not _equal_code(self.is_leaf, other.is_leaf): - raise ValueError("`is_leaf` must match.") - return ω( - jtu.tree_map(op, self.ω, other.ω, is_leaf=self.is_leaf), - is_leaf=self.is_leaf, - ) - elif isinstance(other, (bool, complex, float, int, jnp.ndarray)): - return ω( - jtu.tree_map(lambda x: op(x, other), self.ω, is_leaf=self.is_leaf), - is_leaf=self.is_leaf, - ) - else: - raise RuntimeError("Type of `other` not understood.") - - fn.__name__ = name - fn.__qualname__ = base.__qualname__ + "." + name - setattr(base, name, fn) - - -def _set_unary(base, name: str, op: callable) -> callable: - def fn(self): - return ω( - jtu.tree_map(op, self.ω, is_leaf=self.is_leaf), - is_leaf=self.is_leaf, - ) - - fn.__name__ = name - fn.__qualname__ = base.__qualname__ + "." 
+ name - setattr(base, name, fn) - - -def _rev(op): - def __rev(x, y): - return op(y, x) - - return __rev - - -for (name, op) in [ - ("__add__", operator.add), - ("__sub__", operator.sub), - ("__mul__", operator.mul), - ("__matmul__", operator.matmul), - ("__truediv__", operator.truediv), - ("__floordiv__", operator.floordiv), - ("__mod__", operator.mod), - ("__pow__", operator.pow), - ("__lshift__", operator.lshift), - ("__rshift__", operator.rshift), - ("__and__", operator.and_), - ("__xor__", operator.xor), - ("__or__", operator.or_), - ("__radd__", _rev(operator.add)), - ("__rsub__", _rev(operator.sub)), - ("__rmul__", _rev(operator.mul)), - ("__rmatmul__", _rev(operator.matmul)), - ("__rtruediv__", _rev(operator.truediv)), - ("__rfloordiv__", _rev(operator.floordiv)), - ("__rmod__", _rev(operator.mod)), - ("__rpow__", _rev(operator.pow)), - ("__rlshift__", _rev(operator.lshift)), - ("__rrshift__", _rev(operator.rshift)), - ("__rand__", _rev(operator.and_)), - ("__rxor__", _rev(operator.xor)), - ("__ror__", _rev(operator.or_)), - ("__lt__", operator.lt), - ("__le__", operator.le), - ("__eq__", operator.eq), - ("__ne__", operator.ne), - ("__gt__", operator.gt), - ("__ge__", operator.ge), -]: - _set_binary(ω, name, op) - - -for (name, op) in [ - ("__neg__", operator.neg), - ("__pos__", operator.pos), - ("__abs__", operator.abs), - ("__invert__", operator.invert), -]: - _set_unary(ω, name, op) - - -class _ωUpdateHelper: - def __init__(self, value, is_leaf): - self.value = value - self.is_leaf = is_leaf - - def __getitem__(self, item): - return _ωUpdateRef(self.value, item, self.is_leaf) - - -class _ωUpdateRef: - def __init__(self, value, item, is_leaf): - self.value = value - self.item = item - self.is_leaf = is_leaf - - def get(self, **kwargs): - value, item = self.ω - return value.at[item].get(**kwargs) - - -def _set_binary_at(base, name: str, op: callable) -> callable: - def fn(self, other): - if isinstance(other, ω): - if jtu.tree_structure(self.value) != jtu.tree_structure(other.ω): - raise ValueError("PyTree structures must match.") - if not _equal_code(self.is_leaf, other.is_leaf): - raise ValueError("is_leaf specifications must match.") - return ω( - jtu.tree_map( - lambda x, y: op(x, self.item, y), - self.value, - other.ω, - is_leaf=self.is_leaf, - ), - is_leaf=self.is_leaf, - ) - elif isinstance(other, (bool, complex, float, int, jnp.ndarray)): - return ω( - jtu.tree_map( - lambda x: op(x, self.item, other), self.value, is_leaf=self.is_leaf - ), - is_leaf=self.is_leaf, - ) - else: - raise RuntimeError("Type of `other` not understood.") - - fn.__name__ = name - fn.__qualname__ = base.__qualname__ + "." 
+ name - setattr(base, name, fn) - - -for (name, op) in [ - ("set", lambda x, y, z, **kwargs: x.at[y].set(z, **kwargs)), - ("add", lambda x, y, z, **kwargs: x.at[y].add(z, **kwargs)), - ("multiply", lambda x, y, z, **kwargs: x.at[y].multiply(z, **kwargs)), - ("divide", lambda x, y, z, **kwargs: x.at[y].divide(z, **kwargs)), - ("power", lambda x, y, z, **kwargs: x.at[y].power(z, **kwargs)), - ("min", lambda x, y, z, **kwargs: x.at[y].min(z, **kwargs)), - ("max", lambda x, y, z, **kwargs: x.at[y].max(z, **kwargs)), -]: - _set_binary_at(_ωUpdateRef, name, op) diff --git a/diffrax/misc/unvmap.py b/diffrax/misc/unvmap.py deleted file mode 100644 index d344016f..00000000 --- a/diffrax/misc/unvmap.py +++ /dev/null @@ -1,111 +0,0 @@ -import jax -import jax.interpreters.batching as batching -import jax.interpreters.mlir as mlir -import jax.interpreters.xla as xla -import jax.numpy as jnp - - -# unvmap_all - -_unvmap_all_p = jax.core.Primitive("unvmap_all") - - -def unvmap_all(x): - return _unvmap_all_p.bind(x) - - -def _unvmap_all_impl(x): - return jnp.all(x) - - -def _unvmap_all_abstract_eval(x): - return jax.ShapedArray(shape=(), dtype=jax.numpy.bool_.dtype) - - -def _unvmap_all_batch(x, batch_axes): - (x,) = x - return unvmap_all(x), batching.not_mapped - - -_unvmap_all_p.def_impl(_unvmap_all_impl) -_unvmap_all_p.def_abstract_eval(_unvmap_all_abstract_eval) -batching.primitive_batchers[_unvmap_all_p] = _unvmap_all_batch -if hasattr(xla, "lower_fun"): - xla.register_translation( - _unvmap_all_p, - xla.lower_fun(_unvmap_all_impl, multiple_results=False, new_style=True), - ) -mlir.register_lowering( - _unvmap_all_p, - mlir.lower_fun(_unvmap_all_impl, multiple_results=False), -) - -# unvmap_any - -_unvmap_any_p = jax.core.Primitive("unvmap_any") - - -def unvmap_any(x): - return _unvmap_any_p.bind(x) - - -def _unvmap_any_impl(x): - return jnp.any(x) - - -def _unvmap_any_abstract_eval(x): - return jax.ShapedArray(shape=(), dtype=jax.numpy.bool_.dtype) - - -def _unvmap_any_batch(x, batch_axes): - (x,) = x - return unvmap_any(x), batching.not_mapped - - -_unvmap_any_p.def_impl(_unvmap_any_impl) -_unvmap_any_p.def_abstract_eval(_unvmap_any_abstract_eval) -batching.primitive_batchers[_unvmap_any_p] = _unvmap_any_batch -if hasattr(xla, "lower_fun"): - xla.register_translation( - _unvmap_any_p, - xla.lower_fun(_unvmap_any_impl, multiple_results=False, new_style=True), - ) -mlir.register_lowering( - _unvmap_any_p, - mlir.lower_fun(_unvmap_any_impl, multiple_results=False), -) - -# unvmap_max - -_unvmap_max_p = jax.core.Primitive("unvmap_max") - - -def unvmap_max(x): - return _unvmap_max_p.bind(x) - - -def _unvmap_max_impl(x): - return jnp.max(x) - - -def _unvmap_max_abstract_eval(x): - return jax.ShapedArray(shape=(), dtype=x.dtype) - - -def _unvmap_max_batch(x, batch_axes): - (x,) = x - return unvmap_max(x), batching.not_mapped - - -_unvmap_max_p.def_impl(_unvmap_max_impl) -_unvmap_max_p.def_abstract_eval(_unvmap_max_abstract_eval) -batching.primitive_batchers[_unvmap_max_p] = _unvmap_max_batch -if hasattr(xla, "lower_fun"): - xla.register_translation( - _unvmap_max_p, - xla.lower_fun(_unvmap_max_impl, multiple_results=False, new_style=True), - ) -mlir.register_lowering( - _unvmap_max_p, - mlir.lower_fun(_unvmap_max_impl, multiple_results=False), -) diff --git a/diffrax/path.py b/diffrax/path.py index f599f03c..3ee3484f 100644 --- a/diffrax/path.py +++ b/diffrax/path.py @@ -36,8 +36,8 @@ def evaluate(self, t0, t1=None, left=True): ``` """ - t0: Scalar = field(init=False) - t1: Scalar = field(init=False) + 
t0: Scalar = field(init=False, repr=False) + t1: Scalar = field(init=False, repr=False) @abc.abstractmethod def evaluate( diff --git a/diffrax/solution.py b/diffrax/solution.py index 2545d787..12f3805e 100644 --- a/diffrax/solution.py +++ b/diffrax/solution.py @@ -1,15 +1,15 @@ from dataclasses import field from typing import Any, Dict, Optional +import equinox.internal as eqxi import jax.numpy as jnp from .custom_types import Array, Bool, PyTree, Scalar from .global_interpolation import DenseInterpolation -from .misc import ContainerMeta from .path import AbstractPath -class RESULTS(metaclass=ContainerMeta): +class RESULTS(metaclass=eqxi.ContainerMeta): successful = "" discrete_terminating_event_occurred = ( "Terminating solve because a discrete event occurred." @@ -85,8 +85,8 @@ class Solution(AbstractPath): must allocate enough space for the maximum possible number of steps. """ - t0: Scalar = field(init=True) - t1: Scalar = field(init=True) # override init=False in AbstractPath + t0: Scalar = field(init=True, repr=True) + t1: Scalar = field(init=True, repr=True) # override AbstractPath ts: Optional[Array["times"]] # noqa: F821 ys: Optional[PyTree["times", ...]] # noqa: F821 interpolation: Optional[DenseInterpolation] diff --git a/diffrax/solver/dopri8.py b/diffrax/solver/dopri8.py index e47ea90a..77ba0ab8 100644 --- a/diffrax/solver/dopri8.py +++ b/diffrax/solver/dopri8.py @@ -3,10 +3,11 @@ import jax import jax.numpy as jnp import numpy as np +from equinox.internal import ω from ..custom_types import Array, PyTree, Scalar from ..local_interpolation import AbstractLocalInterpolation -from ..misc import linear_rescale, ω +from ..misc import linear_rescale from .base import vector_tree_dot from .runge_kutta import AbstractERK, ButcherTableau diff --git a/diffrax/solver/euler.py b/diffrax/solver/euler.py index 682c2423..eec2f79a 100644 --- a/diffrax/solver/euler.py +++ b/diffrax/solver/euler.py @@ -1,10 +1,10 @@ from typing import Tuple import jax.tree_util as jtu +from equinox.internal import ω from ..custom_types import Bool, DenseInfo, PyTree, Scalar from ..local_interpolation import LocalLinearInterpolation -from ..misc import ω from ..solution import RESULTS from ..term import AbstractTerm from .base import AbstractItoSolver diff --git a/diffrax/solver/euler_heun.py b/diffrax/solver/euler_heun.py index 3722456b..b8f865ca 100644 --- a/diffrax/solver/euler_heun.py +++ b/diffrax/solver/euler_heun.py @@ -1,10 +1,10 @@ from typing import Tuple import jax.tree_util as jtu +from equinox.internal import ω from ..custom_types import Bool, DenseInfo, PyTree, Scalar from ..local_interpolation import LocalLinearInterpolation -from ..misc import ω from ..solution import RESULTS from ..term import AbstractTerm from .base import AbstractStratonovichSolver diff --git a/diffrax/solver/implicit_euler.py b/diffrax/solver/implicit_euler.py index 5b25bdd2..582b3c53 100644 --- a/diffrax/solver/implicit_euler.py +++ b/diffrax/solver/implicit_euler.py @@ -1,10 +1,10 @@ from typing import Tuple import jax.tree_util as jtu +from equinox.internal import ω from ..custom_types import Bool, DenseInfo, PyTree, Scalar from ..local_interpolation import LocalLinearInterpolation -from ..misc import ω from ..solution import RESULTS from ..term import AbstractTerm from .base import AbstractImplicitSolver diff --git a/diffrax/solver/leapfrog_midpoint.py b/diffrax/solver/leapfrog_midpoint.py index d4c57a86..b563f601 100644 --- a/diffrax/solver/leapfrog_midpoint.py +++ b/diffrax/solver/leapfrog_midpoint.py @@ -1,10 +1,10 
@@ from typing import Tuple import jax.tree_util as jtu +from equinox.internal import ω from ..custom_types import Bool, DenseInfo, PyTree, Scalar from ..local_interpolation import LocalLinearInterpolation -from ..misc import ω from ..solution import RESULTS from ..term import AbstractTerm from .base import AbstractSolver diff --git a/diffrax/solver/milstein.py b/diffrax/solver/milstein.py index a17b030f..9264acdc 100644 --- a/diffrax/solver/milstein.py +++ b/diffrax/solver/milstein.py @@ -3,10 +3,10 @@ import jax import jax.numpy as jnp import jax.tree_util as jtu +from equinox.internal import ω from ..custom_types import Bool, DenseInfo, PyTree, Scalar from ..local_interpolation import LocalLinearInterpolation -from ..misc import ω from ..solution import RESULTS from ..term import AbstractTerm from .base import AbstractItoSolver, AbstractStratonovichSolver diff --git a/diffrax/solver/reversible_heun.py b/diffrax/solver/reversible_heun.py index 429f2b80..eeb86552 100644 --- a/diffrax/solver/reversible_heun.py +++ b/diffrax/solver/reversible_heun.py @@ -2,10 +2,10 @@ import jax.lax as lax import jax.tree_util as jtu +from equinox.internal import ω from ..custom_types import Bool, DenseInfo, PyTree, Scalar from ..local_interpolation import LocalLinearInterpolation -from ..misc import ω from ..solution import RESULTS from ..term import AbstractTerm from .base import AbstractAdaptiveSolver, AbstractStratonovichSolver diff --git a/diffrax/solver/runge_kutta.py b/diffrax/solver/runge_kutta.py index f5921620..9d014542 100644 --- a/diffrax/solver/runge_kutta.py +++ b/diffrax/solver/runge_kutta.py @@ -3,14 +3,15 @@ from typing import Optional, Tuple import equinox as eqx +import equinox.internal as eqxi import jax import jax.lax as lax import jax.numpy as jnp import jax.tree_util as jtu import numpy as np +from equinox.internal import ω from ..custom_types import Bool, DenseInfo, PyTree, Scalar -from ..misc import ContainerMeta, ω from ..solution import is_okay, RESULTS, update_result from ..term import AbstractTerm from .base import AbstractAdaptiveSolver, AbstractImplicitSolver, vector_tree_dot @@ -130,7 +131,7 @@ def __post_init__(self): """ -class CalculateJacobian(metaclass=ContainerMeta): +class CalculateJacobian(metaclass=eqxi.ContainerMeta): """An enumeration of possible ways a Runga--Kutta method may wish to calculate a Jacobian. @@ -206,6 +207,46 @@ def tableau(self) -> ButcherTableau: def calculate_jacobian(self) -> CalculateJacobian: pass + def _first(self, terms, t0, t1, y0, args): + vf_expensive = terms.is_vf_expensive(t0, t1, y0, args) + implicit_first_stage = ( + self.tableau.a_diagonal is not None and self.tableau.a_diagonal[0] != 0 + ) + # The gamut of conditions under which we need to evaluate `f0` or `k0`. + # + # If we're computing the Jacobian at the start of the step, then we + # need this as a linearisation point. + # + # If the first stage is implicit, then we need this as a predictor for + # where to start iterating from. + # + # If we're not scanning stages then we're definitely not deferring this + # evaluation to the scan loop, so get it done now. + need_f0_or_k0 = ( + self.calculate_jacobian == CalculateJacobian.every_step + or implicit_first_stage + or not self.scan_stages + ) + fsal = self.tableau.fsal + if fsal and vf_expensive: + # If the vector field is expensive then we want to use vf_prods instead. + # FSAL implies evaluating just the vector field, since we need to contract + # the same vector field evaluation against two different controls. 
+ # + # But "evaluating just the vector field" is, as just established, expensive. + fsal = False + if fsal and self.scan_stages and not need_f0_or_k0: + # If we're scanning stages then we'd like to disable FSAL. + # FSAL implies evaluating the vector field in `init` as well as in `step`. + # But `scan_stages` is a please-compile-faster flag, so we should avoid the + # extra tracing. + # + # However we disable-the-disabling if `need_f0_or_k0`, since in this case + # we evaluate `f0` or `k0` anyway, so it wouldn't help. So we might as well + # take advantage of the runtime benefits of FSAL. + fsal = False + return vf_expensive, implicit_first_stage, need_f0_or_k0, fsal + def func( self, terms: AbstractTerm, @@ -223,8 +264,7 @@ def init( y0: PyTree, args: PyTree, ) -> _SolverState: - vf_expensive = terms.is_vf_expensive(t0, t1, y0, args) - fsal = self.tableau.fsal and not vf_expensive + _, _, _, fsal = self._first(terms, t0, t1, y0, args) if fsal: return terms.vf(t0, y0, args) else: @@ -277,14 +317,12 @@ def step( # e.g. we need `ks` to perform dense interpolation if needed. # - _vf_expensive = terms.is_vf_expensive(t0, t1, y0, args) _implicit_later_stages = self.tableau.a_diagonal is not None and any( self.tableau.a_diagonal[1:] != 0 ) - implicit_first_stage = ( - self.tableau.a_diagonal is not None and self.tableau.a_diagonal[0] != 0 + _vf_expensive, implicit_first_stage, need_f0_or_k0, fsal = self._first( + terms, t0, t1, y0, args ) - fsal = self.tableau.fsal and not _vf_expensive ssal = self.tableau.ssal if _implicit_later_stages and fsal: use_fs = True @@ -308,28 +346,18 @@ def step( if fsal: f0 = solver_state if not use_fs: - k0 = lax.cond( - made_jump, - lambda _: terms.vf_prod(t0, y0, args, control), - lambda _: terms.prod(f0, control), # noqa: F821 - None, - ) + # `made_jump` can be a tracer, hence the `is`. + if made_jump is False: + # Fast-path for compilation in the common case. + k0 = terms.prod(f0, control) + else: + k0 = lax.cond( + made_jump, + lambda: terms.vf_prod(t0, y0, args, control), + lambda: terms.prod(f0, control), # noqa: F821 + ) else: - if ( - self.calculate_jacobian == CalculateJacobian.every_step - or implicit_first_stage - or not self.scan_stages - ): - # The gamut of conditions under which we need to evaluate `f0` or `k0`. - # - # If we're computing the Jacobian at the start of the step, then we - # need this as a linearisation point. - # - # If the first stage is implicit, then we need this as a predictor for - # where to start iterating from. - # - # If we're not scanning stages then we're definitely not deferring this - # evaluation to the scan loop, so get it done now. 
+ if need_f0_or_k0: if use_fs: f0 = terms.vf(t0, y0, args) else: diff --git a/diffrax/solver/semi_implicit_euler.py b/diffrax/solver/semi_implicit_euler.py index ec9d7cf7..d3a09ae4 100644 --- a/diffrax/solver/semi_implicit_euler.py +++ b/diffrax/solver/semi_implicit_euler.py @@ -1,10 +1,10 @@ from typing import Tuple import jax.tree_util as jtu +from equinox.internal import ω from ..custom_types import Bool, DenseInfo, PyTree, Scalar from ..local_interpolation import LocalLinearInterpolation -from ..misc import ω from ..solution import RESULTS from ..term import AbstractTerm from .base import AbstractSolver diff --git a/diffrax/solver/tsit5.py b/diffrax/solver/tsit5.py index f685055d..8322aad7 100644 --- a/diffrax/solver/tsit5.py +++ b/diffrax/solver/tsit5.py @@ -2,10 +2,11 @@ import jax.numpy as jnp import numpy as np +from equinox.internal import ω from ..custom_types import Array, PyTree, Scalar from ..local_interpolation import AbstractLocalInterpolation -from ..misc import linear_rescale, ω +from ..misc import linear_rescale from .base import vector_tree_dot from .runge_kutta import AbstractERK, ButcherTableau diff --git a/diffrax/step_size_controller/adaptive.py b/diffrax/step_size_controller/adaptive.py index 33996deb..9750976b 100644 --- a/diffrax/step_size_controller/adaptive.py +++ b/diffrax/step_size_controller/adaptive.py @@ -2,13 +2,15 @@ from typing import Callable, Optional, Tuple import equinox as eqx +import equinox.internal as eqxi import jax import jax.lax as lax import jax.numpy as jnp import jax.tree_util as jtu +from equinox.internal import ω from ..custom_types import Array, Bool, PyTree, Scalar -from ..misc import nextafter, prevbefore, rms_norm, ω +from ..misc import rms_norm from ..solution import RESULTS from ..solver import AbstractImplicitSolver, AbstractSolver from ..term import AbstractTerm @@ -368,7 +370,7 @@ def init( t1 = self._clip_step_ts(t0, t0 + dt0) t1, jump_next_step = self._clip_jump_ts(t0, t1) - return t1, (jump_next_step, at_dtmin, dt0, jnp.inf, jnp.inf) + return t1, (jump_next_step, at_dtmin, dt0, 1.0, 1.0) def adapt_step_size( self, @@ -470,46 +472,47 @@ def _scale(_y0, _y1_candidate, _y_error): keep_step = scaled_error < 1 if self.dtmin is not None: keep_step = keep_step | at_dtmin - # Double-where trick to avoid NaN gradients. - # See JAX issues #5039 and #1052. - _cond = scaled_error == 0 - _scaled_error = jnp.where(_cond, 1, scaled_error) - _inv_scaled_error = 1 / _scaled_error - inv_scaled_error = jnp.where(_cond, jnp.inf, _inv_scaled_error) + # Make sure it's not a Python scalar and thus getting a ZeroDivisionError. + inv_scaled_error = 1 / jnp.asarray(scaled_error) + inv_scaled_error = lax.stop_gradient( + inv_scaled_error + ) # See note in init above. + # Note: if you ever remove this lax.stop_gradient, then you'll need to do a lot + # of work to get safe gradients through these operations. + # When `inv_scaled_error` has a (non-symbolic) zero cotangent, and `y_error` + # is either zero or inf, then we get a `0 * inf = nan` on the backward pass. # # Adjust next step size # - # The [prev_]prev_inv_scaled_error can be inf from `self.init(...)`. - # In this case we shouldn't let the extra factors kick in until we've made - # some more steps, so we set the factor to one. - # They can also be inf from the previous `self.adapt_step_size(...)`. In this - # case we had zero estimated error on the previous step and will have already - # increased stepsize by `self.factormax` then. So set the factor to one now. 
- _inf_to_one = lambda x: jnp.where(x == jnp.inf, 1, x) _zero_coeff = lambda c: isinstance(c, (int, float)) and c == 0 coeff1 = (self.icoeff + self.pcoeff + self.dcoeff) / error_order coeff2 = -(self.pcoeff + 2 * self.dcoeff) / error_order coeff3 = self.dcoeff / error_order factor1 = 1 if _zero_coeff(coeff1) else inv_scaled_error ** coeff1 - factor2 = ( - 1 if _zero_coeff(coeff2) else _inf_to_one(prev_inv_scaled_error) ** coeff2 - ) - factor3 = ( - 1 - if _zero_coeff(coeff3) - else _inf_to_one(prev_prev_inv_scaled_error) ** coeff3 - ) + factor2 = 1 if _zero_coeff(coeff2) else prev_inv_scaled_error ** coeff2 + factor3 = 1 if _zero_coeff(coeff3) else prev_prev_inv_scaled_error ** coeff3 factormin = jnp.where(keep_step, 1, self.factormin) factor = jnp.clip( self.safety * factor1 * factor2 * factor3, a_min=factormin, a_max=self.factormax, ) - factor = lax.stop_gradient(factor) # See note in init above. + # Once again, see above. In case we have gradients on {i,p,d}coeff. + # (Probably quite common for them to have zero tangents if passed across + # a grad API boundary as part of a larger model.) + factor = lax.stop_gradient(factor) + factor = eqxi.nondifferentiable(factor) dt = prev_dt * factor + # E.g. we failed an implicit step, so y_error=inf, so inv_scaled_error=0, + # so factor=factormin, and we shrunk our step. + # If we're using a PI or PID controller we shouldn't then force shrinking on + # the next or next two steps as well! + pred = (inv_scaled_error == 0) | jnp.isinf(inv_scaled_error) + inv_scaled_error = jnp.where(pred, 1, inv_scaled_error) + # # Clip next step size based on dtmin/dtmax # @@ -529,13 +532,13 @@ def _scale(_y0, _y1_candidate, _y_error): # Clip next step size based on step_ts/jump_ts # - if jnp.issubdtype(t1.dtype, jnp.inexact): + if jnp.issubdtype(jnp.result_type(t1), jnp.inexact): # Two nextafters. If made_jump then t1 = prevbefore(jump location) # so now _t1 = nextafter(jump location) # This is important because we don't know whether or not the jump is as a # result of a left- or right-discontinuity, so we have to skip the jump # location altogether. - _t1 = jnp.where(made_jump, nextafter(nextafter(t1)), t1) + _t1 = jnp.where(made_jump, eqxi.nextafter(eqxi.nextafter(t1)), t1) else: _t1 = t1 next_t0 = jnp.where(keep_step, _t1, t0) @@ -597,7 +600,7 @@ def _clip_jump_ts(self, t0: Scalar, t1: Scalar) -> Tuple[Scalar, Array[(), bool] raise ValueError( f"jump_ts must be floating point, not {self.jump_ts.dtype}" ) - if not jnp.issubdtype(t1.dtype, jnp.inexact): + if not jnp.issubdtype(jnp.result_type(t1), jnp.inexact): raise ValueError( "t0, t1, dt0 must be floating point when specifying jump_t. Got " f"{t1.dtype}." 
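The step-size update in the hunk above is a PID combination of the current and two previous inverse scaled errors; `lax.stop_gradient` (plus `eqxi.nondifferentiable`) keeps the resulting `0 * inf` products off the backward pass when an error estimate is zero or infinite. A minimal standalone sketch of just the factor computation — with made-up coefficients, `error_order` and clip bounds, and not diffrax's actual API:

```python
import jax.numpy as jnp


def pid_factor(inv_err, prev_inv_err, prev_prev_inv_err, keep_step,
               pcoeff=0.3, icoeff=0.4, dcoeff=0.0, error_order=5.0,
               safety=0.9, factormin=0.2, factormax=10.0):
    # Exponents for the proportional/integral/derivative terms, as in the hunk above.
    coeff1 = (icoeff + pcoeff + dcoeff) / error_order
    coeff2 = -(pcoeff + 2 * dcoeff) / error_order
    coeff3 = dcoeff / error_order
    factor1 = inv_err ** coeff1
    factor2 = prev_inv_err ** coeff2
    factor3 = prev_prev_inv_err ** coeff3
    # A rejected step may only shrink; an accepted step may grow up to factormax.
    factormin = jnp.where(keep_step, 1.0, factormin)
    return jnp.clip(safety * factor1 * factor2 * factor3, a_min=factormin, a_max=factormax)


# e.g. a scaled error of 0.25 on this step (inv_err = 4) grows the next step by ~9%:
print(pid_factor(jnp.array(4.0), jnp.array(1.0), jnp.array(1.0), keep_step=True))
```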
@@ -607,7 +610,7 @@ def _clip_jump_ts(self, t0: Scalar, t1: Scalar) -> Tuple[Scalar, Array[(), bool] next_made_jump = t0_index < t1_index t1 = jnp.where( next_made_jump, - prevbefore(self.jump_ts[jnp.minimum(t0_index, len(self.jump_ts) - 1)]), + eqxi.prevbefore(self.jump_ts[jnp.minimum(t0_index, len(self.jump_ts) - 1)]), t1, ) return t1, next_made_jump diff --git a/diffrax/step_size_controller/constant.py b/diffrax/step_size_controller/constant.py index 7f905d54..7eaf5605 100644 --- a/diffrax/step_size_controller/constant.py +++ b/diffrax/step_size_controller/constant.py @@ -1,10 +1,10 @@ from typing import Callable, Optional, Sequence, Tuple, Union +import equinox.internal as eqxi import jax import jax.numpy as jnp from ..custom_types import Array, Int, PyTree, Scalar -from ..misc import error_if from ..solution import RESULTS from ..term import AbstractTerm from .base import AbstractStepSizeController @@ -93,7 +93,8 @@ def __post_init__(self): def wrap(self, direction: Scalar): ts = self.ts * direction # Only tested after we've set the direction. - error_if( + ts = eqxi.error_if( + ts, ts[1:] <= ts[:-1], "`StepTo(ts=...)` must be strictly increasing (or strictly decreasing if " "t0 > t1).", @@ -117,11 +118,12 @@ def init( "`dt0` should be `None`. Step location is already determined " f"by {type(self).__name__}(ts=...).", ) - error_if( + ts = eqxi.error_if( + self.ts, (t0 != self.ts[0]) | (t1 != self.ts[-1]), "Must have `t0==ts[0]` and `t1==ts[-1]`.", ) - return self.ts[1], 2 + return ts[1], 2 def adapt_step_size( self, diff --git a/diffrax/term.py b/diffrax/term.py index de557070..8ee8b014 100644 --- a/diffrax/term.py +++ b/diffrax/term.py @@ -7,9 +7,9 @@ import jax.numpy as jnp import jax.tree_util as jtu import numpy as np +from equinox.internal import ω from .custom_types import Array, PyTree, Scalar -from .misc import ω from .path import AbstractPath diff --git a/docs/api/adjoints.md b/docs/api/adjoints.md index 39cef0d4..41c5a742 100644 --- a/docs/api/adjoints.md +++ b/docs/api/adjoints.md @@ -29,7 +29,7 @@ There are multiple ways to backpropagate through a differential equation (to com ::: diffrax.NoAdjoint selection: - members: false + members: false ::: diffrax.ImplicitAdjoint selection: diff --git a/docs/devdocs/omega.md b/docs/devdocs/omega.md deleted file mode 100644 index 73b40e51..00000000 --- a/docs/devdocs/omega.md +++ /dev/null @@ -1,49 +0,0 @@ -# Tree mapping with ω - -Looking through the code for the solvers, you may notice a "ω" that keeps popping up, in expressions like: -```python -(ω(x) + ω(y)).ω -``` -or the equivalent -```python -(x**ω + y**ω).ω -``` -which is just a different (fewer-bracket-using) syntax. - -## Usage - -The above are equivalent to just: -```python -jax.tree_util.tree_map(lambda a, b: a + b, x, y) -``` -and are designed just to be a convenient syntax for broadcasting operations over a PyTree. - -Other operations are of course defined: `ω` understands several of the built-in Python operators, including addition, subtraction, matrix multiplication etc. - -!!! tip - - As a convention, we both structure and destructure `ω` on a single line; we never assign a variable that is `ω`-wrapped. Passing `ω`-variables around starts to feel a bit too magic. - -!!! warning - - Note that when doing e.g. `a + ω(b)`, with the `ω` on the right, then things will probably break if `a` is a NumPy array. This is because the overload `a.__add__(ω(b))` is checked before `ω(b).__radd__(a)`, and NumPy will accept pretty much anything. 
The fix is to wrap `a` in a `jnp.ndarray` (which correctly raises `NotImplemented` instead). - -## Commentary - -### Non-goals - -Making anything like `jax.numpy.maximum(x**ω, y**ω)` work is not a goal for `ω`. Just use a regular `jax.tree_util.tree_map` in these situtions. `ω` only aims to support overloadable Python operations, and as a convenience single-argument functions via e.g. `ω(x).call(jax.numpy.abs)`. - -### On syntax - -The syntax might look a little bit odd. The rationale is as follows: - -- A single letter `ω` is used to avoid taking up too much space, so as to keep the terse syntax that e.g. `x + y` provides. -- We use a Greek letter, instead of the more typical Latin characters, to aid visual identification and minimise visual noise. - - Set up an alternate Greek keyboard if you haven't already. (The author is a mathematician and therefore already has this configured...) -- We support the `... ** ω` operation, as well as `ω(...)`, to minimise the number of brackets. For some expressions this reduces visual noise. -- Specifically the `**` operation is used as it has a high precedence -- in particular higher than arithmetic operations. It also pairs visually conveniently with `.ω` (the unwrapping operation): `**` is two high dots, and `.` is one low dot. - -### See also - -See also [tree-math](https://github.com/google/tree-math) for a similar project with a similar idea. One key difference is that `ω` will broadcast leaves together, whilst `tree-math` will not (and is instead meant to feel like a one-dimensional vector in its usage). diff --git a/docs/further_details/faq.md b/docs/further_details/faq.md index fbb0593d..66503dbc 100644 --- a/docs/further_details/faq.md +++ b/docs/further_details/faq.md @@ -4,6 +4,7 @@ - Use `scan_stages=True`, e.g. `Tsit5(scan_stages=True)`. This is supported for all Runge--Kutta methods. This will substantially reduce compile time at the expense of a slightly slower run time. - Set `dt0=`, e.g. `diffeqsolve(..., dt0=0.01)`. In contrast `dt0=None` will determine the initial step size automatically, but will increase compilation time. +- It's an internal (subject-to-change) API, but you can also try adding `equinox.internal.noinline` to your vector field(s), e.g. `ODETerm(noinline(...))`. This stages the vector field out into a separate compilation graph. This can greatly decrease compilation time whilst greatly increasing runtime. ### The solve is taking loads of steps / I'm getting NaN gradients / other weird behaviour. @@ -23,7 +24,7 @@ diffeqsolve( ) ``` -In practice, `Tsit5` is usually a better solver than `Dopri5`. And the default adjoint method (`RecursiveCheckpointAdjoint`) is usually a better choice than `BacksolveAdjoint`. +In practice, [`diffrax.Tsit5`][] is usually a better solver than [`diffrax.Dopri5`][]. And the default adjoint method ([`diffrax.DirectAdjoint`][]) is usually a better choice than [`diffrax.BacksolveAdjoint`][]. ### I'm getting a `CustomVJPException`. 
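To make the FAQ's compile-time suggestions above concrete, the first two compose directly. A minimal sketch with a placeholder vector field and made-up `t0`/`t1`/`dt0` values (the internal `noinline` option is deliberately omitted, since it is described above as subject to change):

```python
import diffrax
import jax.numpy as jnp

term = diffrax.ODETerm(lambda t, y, args: -y)   # placeholder vector field
solver = diffrax.Tsit5(scan_stages=True)        # scan over stages: faster compilation
sol = diffrax.diffeqsolve(
    term,
    solver,
    t0=0.0,
    t1=1.0,
    dt0=0.01,   # an explicit dt0 avoids tracing the initial-step-size heuristic
    y0=jnp.array(1.0),
)
print(sol.ys)
```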
diff --git a/mkdocs.yml b/mkdocs.yml index df3dbf67..ad3020d3 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -132,5 +132,4 @@ nav: - Developer Documentation: - 'devdocs/predictor_dirk.md' - 'devdocs/adjoint_commutative_noise.md' - - 'devdocs/omega.md' - 'devdocs/bounded_while_loop.md' diff --git a/setup.py b/setup.py index 624ef7d4..c7329ad3 100644 --- a/setup.py +++ b/setup.py @@ -46,7 +46,7 @@ python_requires = "~=3.7" -install_requires = ["jax>=0.3.4", "equinox>=0.5.4"] +install_requires = ["jax>=0.3.4", "equinox>=0.9.1"] setuptools.setup( name=name, diff --git a/test/test_adaptive_stepsize_controller.py b/test/test_adaptive_stepsize_controller.py index 4f52123e..6f4dd494 100644 --- a/test/test_adaptive_stepsize_controller.py +++ b/test/test_adaptive_stepsize_controller.py @@ -1,5 +1,7 @@ import diffrax +import equinox as eqx import jax.numpy as jnp +import jax.tree_util as jtu def test_step_ts(): @@ -65,3 +67,26 @@ def run(**kwargs): assert sol.result == 0 assert 3.5 in sol.ts assert 8 in sol.ts + + +def test_backprop(): + @eqx.filter_jit + @eqx.filter_grad + def run(ys, controller, state): + y0, y1_candidate, y_error = ys + _, tprev, tnext, _, state, _ = controller.adapt_step_size( + 0, 1, y0, y1_candidate, None, y_error, 5, state + ) + return tprev + tnext + sum(jnp.sum(x) for x in jtu.tree_leaves(state)) + + y0 = jnp.array(1.0) + y1_candidate = jnp.array(2.0) + term = diffrax.ODETerm(lambda t, y, args: -y) + solver = diffrax.Tsit5() + stepsize_controller = diffrax.PIDController(rtol=1e-4, atol=1e-4) + _, state = stepsize_controller.init(term, 0, 1, y0, 0.1, None, solver.func, 5) + + for y_error in (jnp.array(0.0), jnp.array(3.0), jnp.array(jnp.inf)): + ys = (y0, y1_candidate, y_error) + grads = run(ys, stepsize_controller, state) + assert not any(jnp.isnan(grad).any() for grad in grads) diff --git a/test/test_adjoint.py b/test/test_adjoint.py index b408c5d4..3bf48ec6 100644 --- a/test/test_adjoint.py +++ b/test/test_adjoint.py @@ -24,7 +24,7 @@ def fn(y0): sol = diffrax.diffeqsolve(term, solver, t0, t1, dt0, y0, adjoint=adjoint) return jnp.sum(sol.ys) - with pytest.raises(RuntimeError): + with pytest.raises(ValueError): jax.grad(fn)(1.0) primal, dual = jax.jvp(fn, (1.0,), (1.0,)) @@ -45,7 +45,7 @@ def __call__(self, t, y, args): return jnp.stack([dya, dyb]) -def test_backsolve(getkey): +def test_against(getkey): y0 = jnp.array([0.9, 5.4]) args = (0.1, -1) term = diffrax.ODETerm(_VectorField(nondiff_arg=1, diff_arg=-0.1)) @@ -105,21 +105,22 @@ def _convert_float0(x): if t0 is False and t1 is False and ts is None: continue saveat = diffrax.SaveAt(t0=t0, t1=t1, ts=ts) - true_grads = _run_grad_int( + + recursive_grads = _run_grad( + diff, saveat, diffrax.RecursiveCheckpointAdjoint() + ) + backsolve_grads = _run_grad(diff, saveat, diffrax.BacksolveAdjoint()) + assert shaped_allclose(recursive_grads, backsolve_grads, atol=1e-5) + + recursive_grads = _run_grad_int( y0__args__term, saveat, diffrax.RecursiveCheckpointAdjoint() ) backsolve_grads = _run_grad_int( y0__args__term, saveat, diffrax.BacksolveAdjoint() ) - true_grads = jtu.tree_map(_convert_float0, true_grads) + recursive_grads = jtu.tree_map(_convert_float0, recursive_grads) backsolve_grads = jtu.tree_map(_convert_float0, backsolve_grads) - assert shaped_allclose(true_grads, backsolve_grads) - - true_grads = _run_grad( - diff, saveat, diffrax.RecursiveCheckpointAdjoint() - ) - backsolve_grads = _run_grad(diff, saveat, diffrax.BacksolveAdjoint()) - assert shaped_allclose(true_grads, backsolve_grads) + assert 
shaped_allclose(recursive_grads, backsolve_grads, atol=1e-5) def test_adjoint_seminorm(): diff --git a/test/test_brownian.py b/test/test_brownian.py index f22fece5..8e23a76c 100644 --- a/test/test_brownian.py +++ b/test/test_brownian.py @@ -25,40 +25,27 @@ def test_shape_and_dtype(ctr, getkey): t1 = 2 shapes = ( + (), (0,), - ( - 1, - 0, - ), + (1, 0), (2,), (3, 4), (1, 2, 3, 4), { "a": (1,), - "b": ( - 2, - 3, - ), + "b": (2, 3), }, ( + (1, 2), ( - 1, - 2, - ), - ( - ( - 3, - 4, - ), - ( - 5, - 6, - ), + (3, 4), + (5, 6), ), ), ) dtypes = ( + None, None, None, jnp.float16, diff --git a/test/test_detest.py b/test/test_detest.py index dd8e7952..666125e5 100644 --- a/test/test_detest.py +++ b/test/test_detest.py @@ -15,6 +15,7 @@ import jax.flatten_util as fu import jax.numpy as jnp import jax.tree_util as jtu +import numpy as np import pytest import scipy.integrate as integrate @@ -437,13 +438,13 @@ def _test(solver, problems, higher): y1 = jtu.tree_map(lambda yi: yi[0], sol.ys) scipy_y0, unravel = fu.ravel_pytree(init) - scipy_y0 = scipy_y0.to_py() + scipy_y0 = np.asarray(scipy_y0) def scipy_fn(t, y): y = unravel(y) dy = vector_field(t, y, None) dy, _ = fu.ravel_pytree(dy) - return dy.to_py() + return np.asarray(dy) scipy_sol = integrate.solve_ivp( scipy_fn, diff --git a/test/test_integrate.py b/test/test_integrate.py index 4514556e..41d897e6 100644 --- a/test/test_integrate.py +++ b/test/test_integrate.py @@ -9,7 +9,7 @@ import jax.tree_util as jtu import pytest import scipy.stats -from diffrax.misc import ω +from equinox.internal import ω from .helpers import ( all_ode_solvers, diff --git a/test/test_misc.py b/test/test_misc.py index 7443bbc0..7f60963a 100644 --- a/test/test_misc.py +++ b/test/test_misc.py @@ -1,10 +1,7 @@ import diffrax -import jax import jax.numpy as jnp -import jax.tree_util as jtu -import pytest -from .helpers import random_pytree, shaped_allclose, treedefs +from .helpers import shaped_allclose def test_fill_forward(): @@ -12,145 +9,3 @@ def test_fill_forward(): out_ = jnp.array([jnp.nan, 0.0, 1.0, 1.0, 1.0, 2.0, 2.0]) fill_in = diffrax.misc.fill_forward(in_[:, None]) assert shaped_allclose(fill_in, out_[:, None], equal_nan=True) - - -def test_ω_add_mul(getkey): - # ω(...) 
initialisation - ω = diffrax.misc.ω - a = [0, 1] - b = [1, 2] - c = (ω(a) + ω(b)).ω - assert c == [1, 3] - - # ...**ω initialisation - for treedef in treedefs: - a = b = c = random_pytree(getkey(), treedef) - - e1 = (a**ω * 2 + b**ω * c**ω - 3).ω - e2 = jtu.tree_map(lambda ai, bi, ci: ai * 2 + bi * ci - 3, a, b, c) - assert shaped_allclose(e1, e2) - - -def test_ω_inplace(getkey): - ω = diffrax.misc.ω - for treedef in treedefs: - a = random_pytree(getkey(), treedef) - b1 = ω(a).at[()].set(3).ω - b2 = jtu.tree_map(lambda ai: ai.at[()].set(3), a) - assert shaped_allclose(b1, b2) - - a2 = jtu.tree_map(lambda x: x + 1, a) - - b3 = ω(a).at[()].set(ω(a2)).ω - b4 = jtu.tree_map(lambda ai, a2i: ai.at[()].set(a2i[()]), a, a2) - assert shaped_allclose(b3, b4) - - -def test_ω_is_leaf(getkey): - ω = diffrax.misc.ω - for treedef in treedefs: - a = b = random_pytree(getkey(), treedef) - with pytest.raises(ValueError): - ω(a) + ω(b, is_leaf=lambda x: isinstance(x, int)) - with pytest.raises(ValueError): - ω(a, is_leaf=lambda x: isinstance(x, int)) + ω(b) - with pytest.raises(ValueError): - ω(a, is_leaf=lambda x: isinstance(x, int)) + ω( - b, is_leaf=lambda x: isinstance(x, (int, str)) - ) - - out = ω(a, is_leaf=lambda x: isinstance(x, int)) + ω( - b, is_leaf=lambda x: isinstance(x, int) - ) - assert out.is_leaf(4) - assert not out.is_leaf("hi") - - b = ω(a, is_leaf=lambda x: isinstance(x, int)).at[()].set(3) - assert out.is_leaf(4) - assert not out.is_leaf("hi") - - a2 = jtu.tree_map(lambda x: x + 1, a) - - c = ( - ω(a, is_leaf=lambda x: isinstance(x, int)) - .at[()] - .set(ω(a2, is_leaf=lambda x: isinstance(x, int))) - ) - assert c.is_leaf(4) - assert not c.is_leaf("hi") - - with pytest.raises(ValueError): - ω(a, is_leaf=lambda x: isinstance(x, int)).at[()].set(ω(a2)) - with pytest.raises(ValueError): - ω(a).at[()].set(ω(a2, is_leaf=lambda x: isinstance(x, int))) - - -def test_unvmap(): - unvmap_all = diffrax.misc.unvmap_all - unvmap_any = diffrax.misc.unvmap_any - jit_unvmap_all = jax.jit(unvmap_all) - jit_unvmap_any = jax.jit(unvmap_any) - vmap_unvmap_all = jax.vmap(unvmap_all, out_axes=None) - vmap_unvmap_any = jax.vmap(unvmap_any, out_axes=None) - - tt = jnp.array([True, True]) - tf = jnp.array([True, False]) - ff = jnp.array([False, False]) - - assert jnp.array_equal(unvmap_all(tt), jnp.array(True)) - assert jnp.array_equal(unvmap_all(tf), jnp.array(False)) - assert jnp.array_equal(unvmap_all(ff), jnp.array(False)) - assert jnp.array_equal(unvmap_any(tt), jnp.array(True)) - assert jnp.array_equal(unvmap_any(tf), jnp.array(True)) - assert jnp.array_equal(unvmap_any(ff), jnp.array(False)) - - assert jnp.array_equal(jit_unvmap_all(tt), jnp.array(True)) - assert jnp.array_equal(jit_unvmap_all(tf), jnp.array(False)) - assert jnp.array_equal(jit_unvmap_all(ff), jnp.array(False)) - assert jnp.array_equal(jit_unvmap_any(tt), jnp.array(True)) - assert jnp.array_equal(jit_unvmap_any(tf), jnp.array(True)) - assert jnp.array_equal(jit_unvmap_any(ff), jnp.array(False)) - - assert jnp.array_equal(vmap_unvmap_all(tt), jnp.array(True)) - assert jnp.array_equal(vmap_unvmap_all(tf), jnp.array(False)) - assert jnp.array_equal(vmap_unvmap_all(ff), jnp.array(False)) - assert jnp.array_equal(vmap_unvmap_any(tt), jnp.array(True)) - assert jnp.array_equal(vmap_unvmap_any(tf), jnp.array(True)) - assert jnp.array_equal(vmap_unvmap_any(ff), jnp.array(False)) - - unvmap_max = diffrax.misc.unvmap_max - jit_unvmap_max = jax.jit(unvmap_max) - vmap_unvmap_max = jax.vmap(unvmap_max, out_axes=None) - - _21 = jnp.array([2, 1]) - _11 
= jnp.array([1, 1]) - - assert jnp.array_equal(unvmap_max(_21), jnp.array(2)) - assert jnp.array_equal(unvmap_max(_11), jnp.array(1)) - - assert jnp.array_equal(jit_unvmap_max(_21), jnp.array(2)) - assert jnp.array_equal(jit_unvmap_max(_11), jnp.array(1)) - - assert jnp.array_equal(vmap_unvmap_max(_21), jnp.array(2)) - assert jnp.array_equal(vmap_unvmap_max(_11), jnp.array(1)) - - -def test_nondifferentiable_input(): - ndi = lambda x: diffrax.misc.nondifferentiable_input(x, "name") - ndi( - jnp.array( - 2, - ) - ) # no error - with pytest.raises(ValueError): - jax.jvp(ndi, (jnp.array(2.0),), (jnp.array(1.0),)) - with pytest.raises(ValueError): - jax.grad(ndi)(jnp.array(2.0)) - - -def test_nondifferentiable_output(): - ndo = diffrax.misc.nondifferentiable_output - ndo(jnp.array(2.0)) # no error - jax.jvp(ndo, (jnp.array(2.0),), (jnp.array(1.0),)) # no error - with pytest.raises(RuntimeError): - jax.grad(ndo)(jnp.array(2.0)) diff --git a/test/test_vmap.py b/test/test_vmap.py index e70a4830..191d0b48 100644 --- a/test/test_vmap.py +++ b/test/test_vmap.py @@ -2,6 +2,7 @@ import jax import jax.numpy as jnp import jax.random as jrandom +import numpy as np import pytest @@ -93,7 +94,7 @@ def f(t, y, args): num_steps = sol.stats["num_steps"] if not isinstance(stepsize_controller, diffrax.ConstantStepSize): # not the same number of steps for every batch element - assert len(set(num_steps.to_py())) > 1 + assert len(set(np.asarray(num_steps))) > 1 assert jnp.array_equal(sol.t0, jnp.full((10,), t0)) assert jnp.array_equal(sol.t1, jnp.full((10,), t1)) assert sol.ts.shape == ( From ee39d820078a6f14b0a088c8efbbbb009406aa73 Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Mon, 14 Nov 2022 22:56:23 -0800 Subject: [PATCH 5/5] Update adjoints.md --- docs/api/adjoints.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/api/adjoints.md b/docs/api/adjoints.md index 41c5a742..39cef0d4 100644 --- a/docs/api/adjoints.md +++ b/docs/api/adjoints.md @@ -29,7 +29,7 @@ There are multiple ways to backpropagate through a differential equation (to com ::: diffrax.NoAdjoint selection: - members: false + members: false ::: diffrax.ImplicitAdjoint selection:
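A final note on the test-suite hunks above (`test_detest.py`, `test_vmap.py`): they swap the deprecated `DeviceArray.to_py()` method for `np.asarray(...)`, which performs the same device-to-host conversion. A quick sketch of the equivalent pattern, with arbitrary example values:

```python
import jax.numpy as jnp
import numpy as np

x = jnp.arange(3)
host_x = np.asarray(x)   # was: x.to_py()
assert isinstance(host_x, np.ndarray)
print(host_x)
```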