
Diffrax v0.6.0

Released by github-actions on 01 Jul 09:16

Features

  • Continuous events! It is now possible to specify a condition at which the differential equation solve should halt. For example, here's one that finds the time at which a dropped ball hits the ground:

    import diffrax
    import jax.numpy as jnp
    import optimistix as optx
    
    def vector_field(t, y, args):
        _, v = y
        return jnp.array([v, -9.81])
    
    def cond_fn(t, y, args, **kwargs):
        x, _ = y
        return x
    
    term = diffrax.ODETerm(vector_field)
    solver = diffrax.Tsit5()
    t0 = 0
    t1 = jnp.inf
    dt0 = 0.1
    y0 = jnp.array([10.0, 0.0])
    root_finder = optx.Newton(1e-5, 1e-5, optx.rms_norm)
    event = diffrax.Event(cond_fn, root_finder)
    sol = diffrax.diffeqsolve(term, solver, t0, t1, dt0, y0, event=event)
    print(f"Event time: {sol.ts[0]}") # Event time: 1.42...
    print(f"Velocity at event time: {sol.ys[0, 1]}") # Velocity at event time: -14.00...

    The solve stops when cond_fn hits zero: once cond_fn changes sign over a step, Optimistix is used to root-find the exact time at which the solve should terminate. Event handling is also fully differentiable.

    Getting this in was a huge amount of work from @cholberg (thank you!), and it has been one of our longest-requested features for a while, so I'm really happy to have it in.

    (We previously only had 'discrete events', which just terminated at the end of a step, and did not do a root find.)

    See the events page in the documentation for more.

  • Simulation of space-time-time Lévy area. This is a higher-order statistic of Brownian motion, used in some advanced SDE solvers. We don't have any such solvers yet, but watch this space... ;)

    This was a hugely impressive technical effort from @andyElking. Check out our arXiv paper on the topic, which discusses the technical nitty-gritty of how these statistics can be simulated efficiently.

  • ControlTerm now supports returning a Lineax linear operator. For example, here's how to easily create a diagonal diffusion term:

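    import jax.numpy as jnp
    import lineax
    from diffrax import ControlTerm

    def vector_field(t, y, args):
        # y is a JAX array of shape (2,)
        y1, y2 = y
        diagonal = jnp.array([y2, y1])
        # Represents the matrix [[y2, 0], [0, y1]] without storing the zeros.
        return lineax.DiagonalLinearOperator(diagonal)

    # The control (for example a Brownian motion) is elided here.
    diffusion_term = ControlTerm(vector_field, ...)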

    This should make it much easier to express SDEs whose diffusion matrices have particular structure.

    It is also good for efficiency: the operator's own .mv (matrix-vector product) method is used, which is typically faster than filling in the zeros and doing a dense matrix-vector product. For a diagonal operator, for example, this is an O(n) elementwise multiply rather than an O(n²) dense product.

    Thank you to @lockwo for implementing this one!

    See the documentation on ControlTerm for more.
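The efficiency argument in the ControlTerm bullet above can be illustrated without Diffrax or Lineax. The sketch below is a minimal illustration, assuming only NumPy: `DiagonalOperator` is a hypothetical stand-in (not Lineax's API) that stores just the diagonal and implements `mv` as an elementwise product. It checks that this agrees with the dense product using the zero-padded matrix `[[y2, 0], [0, y1]]` from the example.

```python
import numpy as np


class DiagonalOperator:
    """Hypothetical stand-in for a diagonal linear operator (not Lineax's API).

    Only the diagonal is stored; `mv` never materialises the zeros.
    """

    def __init__(self, diagonal):
        self.diagonal = np.asarray(diagonal, dtype=float)

    def mv(self, vector):
        # Diagonal matrix-vector product: an O(n) elementwise multiply.
        return self.diagonal * np.asarray(vector, dtype=float)

    def as_matrix(self):
        # Dense equivalent with O(n^2) storage, for comparison only.
        return np.diag(self.diagonal)


def vector_field(t, y, args):
    # Same structure as the release-note example: diag(y2, y1).
    y1, y2 = y
    return DiagonalOperator([y2, y1])


y = np.array([3.0, 5.0])
dw = np.array([0.1, -0.2])  # hypothetical stand-in for a Brownian increment
op = vector_field(0.0, y, None)

structured = op.mv(dw)       # [5 * 0.1, 3 * -0.2] = [0.5, -0.6]
dense = op.as_matrix() @ dw  # same result via a dense 2x2 matrix product
print(structured, dense)
```

Storing only the diagonal is the design point: the structured path touches n numbers, while the dense path builds and multiplies an n-by-n matrix that is mostly zeros.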

Deprecations

Two APIs have now been deprecated.

Both of these APIs now have compatibility layers, so existing code should continue to work. However, they will now emit deprecation warnings, and users are encouraged to upgrade. These APIs may be removed at a later date.

  • diffeqsolve(..., discrete_terminating_event=...), along with the corresponding classes AbstractDiscreteTerminatingEvent + DiscreteTerminatingEvent + SteadyStateEvent. These have been superseded by diffeqsolve(..., event=Event(...)).

  • WeaklyDiagonalControlTerm has been superseded by ControlTerm's new support for Lineax linear operators, as discussed above.
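For anyone migrating from discrete_terminating_event, the practical difference is where the solve stops: a discrete event terminates at the end of the step in which the condition is first met, whereas a continuous Event root-finds inside that step. The pure-Python sketch below shows both on the dropped-ball example. It is not Diffrax's implementation: it swaps in a fixed-step RK4 integrator for Tsit5 and plain bisection for the Optimistix root finder, and all function names are hypothetical.

```python
import math

G = 9.81


def vector_field(t, y):
    # y = (height x, velocity v); same dynamics as the release-note example.
    _, v = y
    return (v, -G)


def rk4_step(t, y, h):
    # One classical Runge-Kutta step of size h (a stand-in for Tsit5).
    def add(a, b, s):
        return tuple(ai + s * bi for ai, bi in zip(a, b))

    k1 = vector_field(t, y)
    k2 = vector_field(t + h / 2, add(y, k1, h / 2))
    k3 = vector_field(t + h / 2, add(y, k2, h / 2))
    k4 = vector_field(t + h, add(y, k3, h))
    return tuple(
        yi + h / 6 * (a + 2 * b + 2 * c + d)
        for yi, a, b, c, d in zip(y, k1, k2, k3, k4)
    )


def cond_fn(t, y):
    # The event fires when the height crosses zero.
    x, _ = y
    return x


def solve_with_event(y0, dt, tol=1e-12, max_steps=10_000):
    t, y = 0.0, y0
    for _ in range(max_steps):
        y_next = rk4_step(t, y, dt)
        if cond_fn(t + dt, y_next) <= 0:  # sign change within [t, t + dt]
            discrete_time = t + dt  # discrete behaviour: stop at end of step
            # Continuous behaviour: bisect on the step size h in [0, dt],
            # re-taking a single step of size h from the start of the step.
            lo, hi = 0.0, dt
            while hi - lo > tol:
                mid = (lo + hi) / 2
                if cond_fn(t + mid, rk4_step(t, y, mid)) > 0:
                    lo = mid
                else:
                    hi = mid
            return discrete_time, t + hi, rk4_step(t, y, hi)
        t, y = t + dt, y_next
    raise RuntimeError("no event within max_steps")


discrete_time, event_time, y_event = solve_with_event((10.0, 0.0), dt=0.1)
analytic_time = math.sqrt(2 * 10.0 / G)  # free fall from 10 m
print(discrete_time)  # ~1.5: end of the step in which the ball lands
print(event_time, analytic_time)  # both ~1.4278
print(y_event[1])  # ~-14.007
```

With the ODE above, RK4 happens to reproduce the exact quadratic trajectory, so the bisected time matches the closed-form free-fall time; the discrete stop overshoots it by most of a step.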

Other

  • We now work around an upstream bug introduced in JAX 0.4.29, so Diffrax should be compatible with recent JAX releases.
  • Diffrax no longer triggers JAX's deprecation warnings for a few old APIs. (We've migrated to their replacements.)

Full Changelog: v0.5.1...v0.6.0