Merge pull request #190 from patrick-kidger/internal4
Upgrade to `equinox.internal`
patrick-kidger authored Nov 15, 2022
2 parents b847552 + ee39d82 commit ea1bdc9
Showing 45 changed files with 637 additions and 933 deletions.
80 changes: 80 additions & 0 deletions benchmarks/compile_times.py
@@ -0,0 +1,80 @@
import functools as ft
import timeit

import diffrax as dfx
import equinox as eqx
import fire
import jax
import jax.numpy as jnp
import jax.random as jr


def _weight(in_, out, key):
    return [[w_ij for w_ij in w_i] for w_i in jr.normal(key, (out, in_))]


class VectorField(eqx.Module):
    weights: list

    def __init__(self, in_, out, width, depth, *, key):
        keys = jr.split(key, depth + 1)
        self.weights = [_weight(in_, width, keys[0])]
        for i in range(1, depth):
            self.weights.append(_weight(width, width, keys[i]))
        self.weights.append(_weight(width, out, keys[depth]))

    def __call__(self, t, y, args):
        # Inefficient computation graph to make a toy example more expensive.
        y = [y_i for y_i in y]
        for w in self.weights:
            y = [sum(w_ij * y_j for w_ij, y_j in zip(w_i, y)) for w_i in w]
        return jnp.stack(y)


def main(inline: bool, scan_stages: bool, grad: bool, adjoint: str):
    if adjoint == "direct":
        adjoint = dfx.DirectAdjoint()
    elif adjoint == "recursive":
        adjoint = dfx.RecursiveCheckpointAdjoint()
    elif adjoint == "backsolve":
        adjoint = dfx.BacksolveAdjoint()
    else:
        raise ValueError
    if grad:
        grad_decorator = jax.grad
    else:
        grad_decorator = lambda x: x

    vf = VectorField(1, 1, 16, 2, key=jr.PRNGKey(0))
    if not inline:
        vf = eqx.internal.noinline(vf)
    term = dfx.ODETerm(vf)
    solver = dfx.Dopri8(scan_stages=scan_stages)
    stepsize_controller = dfx.PIDController(rtol=1e-3, atol=1e-6)
    t0 = 0
    t1 = 1
    dt0 = 0.01

    @jax.jit
    @grad_decorator
    def solve(y0):
        sol = dfx.diffeqsolve(
            term,
            solver,
            t0,
            t1,
            dt0,
            y0,
            stepsize_controller=stepsize_controller,
            adjoint=adjoint,
            max_steps=16**2,
        )
        return jnp.sum(sol.ys)

    solve_ = ft.partial(solve, jnp.array([1.0]))
    print("Compile+run time", timeit.timeit(solve_, number=1))
    print("Run time", timeit.timeit(solve_, number=1))


if __name__ == "__main__":
    fire.Fire(main)
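
The new benchmark exposes `main`'s arguments through `fire.Fire`, so it can be driven from the command line or imported directly. A minimal sketch, assuming it is run from inside `benchmarks/` so that `compile_times` is importable (not part of the diff):

# Drive the benchmark above from Python; the equivalent fire CLI call would be
#   python compile_times.py --inline=False --scan_stages=True --grad=True --adjoint=recursive
from compile_times import main

main(inline=False, scan_stages=True, grad=True, adjoint="recursive")
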
9 changes: 5 additions & 4 deletions benchmarks/small_neural_ode.py
@@ -9,6 +9,7 @@
import jax.nn as jnn
import jax.numpy as jnp
import jax.random as jrandom
+import numpy as np
import torch
import torchdiffeq

@@ -173,10 +174,10 @@ def main(batch_size=64, t1=100, multiple=False, grad=False):
    with torch.no_grad():
        func_jax = neural_ode_diffrax.func.func
        func_torch = neural_ode_torch.func.func
-        func_torch[0].weight.copy_(torch.tensor(func_jax.layers[0].weight.to_py()))
-        func_torch[0].bias.copy_(torch.tensor(func_jax.layers[0].bias.to_py()))
-        func_torch[2].weight.copy_(torch.tensor(func_jax.layers[1].weight.to_py()))
-        func_torch[2].bias.copy_(torch.tensor(func_jax.layers[1].bias.to_py()))
+        func_torch[0].weight.copy_(torch.tensor(np.asarray(func_jax.layers[0].weight)))
+        func_torch[0].bias.copy_(torch.tensor(np.asarray(func_jax.layers[0].bias)))
+        func_torch[2].weight.copy_(torch.tensor(np.asarray(func_jax.layers[1].weight)))
+        func_torch[2].bias.copy_(torch.tensor(np.asarray(func_jax.layers[1].bias)))

    y0_jax = jrandom.normal(jrandom.PRNGKey(1), (batch_size, 4))
    y0_torch = torch.tensor(y0_jax.to_py())
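
The changed lines above swap the JAX-to-PyTorch copy from `.to_py()` to `np.asarray(...)`. A minimal sketch of the same conversion pattern in isolation (toy shapes, not the benchmark's real layers; not part of the diff):

import jax.numpy as jnp
import numpy as np
import torch

jax_weight = jnp.ones((2, 2))                        # stand-in for a real layer weight
torch_weight = torch.tensor(np.asarray(jax_weight))  # same conversion as the new lines above
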
2 changes: 1 addition & 1 deletion diffrax/__init__.py
@@ -87,4 +87,4 @@
)


__version__ = "0.2.1"
__version__ = "0.2.2"
139 changes: 108 additions & 31 deletions diffrax/adjoint.py
@@ -2,15 +2,67 @@
from typing import Any, Dict

import equinox as eqx
+import equinox.internal as eqxi
import jax.lax as lax
import jax.numpy as jnp
import jax.tree_util as jtu
+from equinox.internal import ω

-from .misc import implicit_jvp, nondifferentiable_output, ω
+from .misc import implicit_jvp
from .saveat import SaveAt
from .term import AbstractTerm, AdjointTerm


def _is_none(x):
    return x is None


def _no_transpose_final_state(final_state):
    y = eqxi.nondifferentiable_backward(final_state.y, name="y")
    tprev = eqxi.nondifferentiable_backward(final_state.tprev, name="tprev")
    tnext = eqxi.nondifferentiable_backward(final_state.tnext, name="tnext")
    solver_state = eqxi.nondifferentiable_backward(
        final_state.solver_state, name="solver_state"
    )
    controller_state = eqxi.nondifferentiable_backward(
        final_state.controller_state, name="controller_state"
    )
    ts = eqxi.nondifferentiable_backward(final_state.ts, name="ts")
    ys = final_state.ys
    dense_ts = eqxi.nondifferentiable_backward(final_state.dense_ts, name="dense_ts")
    dense_infos = eqxi.nondifferentiable_backward(
        final_state.dense_infos, name="dense_infos"
    )
    final_state = eqxi.nondifferentiable_backward(final_state)  # no more specific name
    final_state = eqx.tree_at(
        lambda s: (
            s.y,
            s.tprev,
            s.tnext,
            s.solver_state,
            s.controller_state,
            s.ts,
            s.ys,
            s.dense_ts,
            s.dense_infos,
        ),
        final_state,
        (
            y,
            tprev,
            tnext,
            solver_state,
            controller_state,
            ts,
            ys,
            dense_ts,
            dense_infos,
        ),
        is_leaf=_is_none,
    )
    return final_state


class AbstractAdjoint(eqx.Module):
"""Abstract base class for all adjoint methods."""

@@ -30,6 +82,8 @@ def loop(
        max_steps,
        throw,
        init_state,
+        passed_solver_state,
+        passed_controller_state,
    ):
        """Runs the main solve loop. Subclasses can override this to provide custom
        backpropagation behaviour; see for example the implementation of
@@ -69,27 +123,26 @@ class RecursiveCheckpointAdjoint(AbstractAdjoint):
    For most problems this is the preferred technique for backpropagating through a
    differential equation.
-    A binomial checkpointing scheme is used so that memory usage is low.
+    In addition a binomial checkpointing scheme is used so that memory usage is low.
    (This checkpointing can increase compile time a bit, though.)
    """

-    def loop(self, *, throw, **kwargs):
-        del throw
+    def loop(self, *, throw, passed_solver_state, passed_controller_state, **kwargs):
+        del throw, passed_solver_state, passed_controller_state
        return self._loop_fn(**kwargs, is_bounded=True)


class NoAdjoint(AbstractAdjoint):
"""Disable backpropagation through [`diffrax.diffeqsolve`][].
Forward-mode autodifferentiation (`jax.jvp`) will continue to work as normal.
If you do not need to differentiate the results of [`diffrax.diffeqsolve`][] then
this may sometimes improve the speed at which the differential equation is solved.
"""

def loop(self, *, throw, **kwargs):
del throw
def loop(self, *, throw, passed_solver_state, passed_controller_state, **kwargs):
del throw, passed_solver_state, passed_controller_state
final_state, aux_stats = self._loop_fn(**kwargs, is_bounded=False)
final_state = jtu.tree_map(nondifferentiable_output, final_state)
final_state = eqxi.nondifferentiable_backward(final_state)
return final_state, aux_stats


@@ -135,7 +188,19 @@ class ImplicitAdjoint(AbstractAdjoint):
    via the implicit function theorem.
    """ # noqa: E501

-    def loop(self, *, args, terms, solver, saveat, throw, init_state, **kwargs):
+    def loop(
+        self,
+        *,
+        args,
+        terms,
+        solver,
+        saveat,
+        throw,
+        init_state,
+        passed_solver_state,
+        passed_controller_state,
+        **kwargs,
+    ):
        del throw

        # `is` check because this may return a Tracer from SaveAt(ts=<array>)
@@ -144,21 +209,30 @@ def loop(self, *, args, terms, solver, saveat, throw, init_state, **kwargs):
                "Can only use `adjoint=ImplicitAdjoint()` with `SaveAt(t1=True)`."
            )

-        init_state = eqx.tree_at(
-            lambda s: (s.y, s.solver_state, s.controller_state),
-            init_state,
-            replace_fn=lax.stop_gradient,
-        )
+        if not passed_solver_state:
+            init_state = eqx.tree_at(
+                lambda s: s.solver_state,
+                init_state,
+                replace_fn=lax.stop_gradient,
+                is_leaf=_is_none,
+            )
+        if not passed_controller_state:
+            init_state = eqx.tree_at(
+                lambda s: s.controller_state,
+                init_state,
+                replace_fn=lax.stop_gradient,
+                is_leaf=_is_none,
+            )

        closure = (self, kwargs, solver, saveat, init_state)
        ys, residual = implicit_jvp(_solve, _vf, (args, terms), closure)

        final_state_no_ys, aux_stats = residual
-        return (
-            eqx.tree_at(
-                lambda s: s.ys, final_state_no_ys, ys, is_leaf=lambda x: x is None
-            ),
-            aux_stats,
+        final_state = eqx.tree_at(
+            lambda s: s.ys, final_state_no_ys, ys, is_leaf=_is_none
        )
+        final_state = _no_transpose_final_state(final_state)
+        return final_state, aux_stats


# Compute derivatives with respect to the first argument:
Expand All @@ -174,7 +248,7 @@ def _loop_backsolve(y__args__terms, *, self, throw, init_state, **kwargs):
    )
    del y
    return self._loop_fn(
-        args=args, terms=terms, init_state=init_state, **kwargs, is_bounded=False
+        args=args, terms=terms, init_state=init_state, is_bounded=False, **kwargs
    )


@@ -398,7 +472,18 @@ def __init__(self, **kwargs):
            )
        self.kwargs = kwargs

-    def loop(self, *, args, terms, saveat, init_state, **kwargs):
+    def loop(
+        self,
+        *,
+        args,
+        terms,
+        saveat,
+        init_state,
+        passed_solver_state,
+        passed_controller_state,
+        **kwargs,
+    ):
+        del passed_solver_state, passed_controller_state
        if saveat.steps or saveat.dense:
            raise NotImplementedError(
                "Cannot use `adjoint=BacksolveAdjoint()` with "
@@ -414,13 +499,5 @@ def loop(self, *, args, terms, saveat, init_state, **kwargs):
        final_state, aux_stats = _loop_backsolve(
            (y, args, terms), self=self, saveat=saveat, init_state=init_state, **kwargs
        )

-        # We only allow backpropagation through `ys`; in particular not through
-        # `solver_state` etc.
-        ys = final_state.ys
-        final_state = jtu.tree_map(nondifferentiable_output, final_state)
-        final_state = eqx.tree_at(
-            lambda s: jtu.tree_leaves(s.ys), final_state, jtu.tree_leaves(ys)
-        )

+        final_state = _no_transpose_final_state(final_state)
        return final_state, aux_stats
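
For context, every adjoint class in this file is selected through the `adjoint` argument of `diffeqsolve`, as the new benchmark above also does. A minimal usage sketch (not taken from the diff; the solver, tolerances and numbers here are arbitrary):

import diffrax as dfx
import jax.numpy as jnp

term = dfx.ODETerm(lambda t, y, args: -y)
sol = dfx.diffeqsolve(
    term,
    dfx.Dopri5(),
    t0=0.0,
    t1=1.0,
    dt0=0.1,
    y0=jnp.array([1.0]),
    adjoint=dfx.RecursiveCheckpointAdjoint(),  # or BacksolveAdjoint(), ImplicitAdjoint(), NoAdjoint()
)
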
6 changes: 3 additions & 3 deletions diffrax/brownian/base.py
@@ -1,14 +1,14 @@
import abc

-from ..custom_types import Array, Scalar
+from ..custom_types import Array, PyTree, Scalar
from ..path import AbstractPath


class AbstractBrownianPath(AbstractPath):
"Abstract base class for all Brownian paths."

    @abc.abstractmethod
-    def evaluate(self, t0: Scalar, t1: Scalar, left: bool = True) -> Array:
+    def evaluate(self, t0: Scalar, t1: Scalar, left: bool = True) -> PyTree[Array]:
        r"""Samples a Brownian increment $w(t_1) - w(t_0)$.
        Each increment has distribution $\mathcal{N}(0, t_1 - t_0)$.
@@ -23,7 +23,7 @@ def evaluate(self, t0: Scalar, t1: Scalar, left: bool = True) -> Array:
        **Returns:**
-        A JAX array corresponding to the increment $w(t_1) - w(t_0)$.
+        A pytree of JAX arrays corresponding to the increment $w(t_1) - w(t_0)$.
        Some subclasses may allow `t1=None`, in which case just the value $w(t_0)$ is
        returned.
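
As a usage sketch (not part of the diff): a concrete subclass such as `VirtualBrownianTree` implements this interface, and under the broadened signature `evaluate` may return a pytree of arrays rather than a single array. The constructor arguments below are assumed from the diffrax documentation of this period:

import diffrax as dfx
import jax.random as jr

bm = dfx.VirtualBrownianTree(t0=0.0, t1=1.0, tol=1e-3, shape=(2,), key=jr.PRNGKey(0))
dw = bm.evaluate(t0=0.25, t1=0.75)  # increment w(0.75) - w(0.25), variance t1 - t0 per component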