JAX Talk: Diffrax

Details

In this session, we'll discuss “Diffrax”, a JAX-based suite of ordinary/stochastic/controlled differential equation solvers. It is roughly analogous to the existing torchdiffeq/torchsde packages in the PyTorch ecosystem and DifferentialEquations.jl in the Julia ecosystem.

Highlights include:

  • High performance: e.g. a 200x speedup over torchdiffeq, and performance similar to DifferentialEquations.jl.
  • Numerous features: high-order solvers, implicit solvers, dense solutions, multiple adjoint methods, etc.
  • Integrates directly with JAX: jit/grad/vmap/etc. all work as normal.
  • Easily extensible with custom ops (solvers etc.), including the ability to handle the stepping yourself when writing your own differentiable simulator.
  • For the numerical analyst: ODEs/SDEs/etc. are all solved in a unified way.
JAX Global Meetup
Online event
This event has passed