JAX Talk: Diffrax

Hosted By
Sanyam B. and 2 others

Details
In this session, we'll discuss “Diffrax”, a JAX-based suite of ordinary/stochastic/controlled differential equation solvers, roughly analogous to the existing torchdiffeq/torchsde packages in the PyTorch ecosystem and DifferentialEquations.jl in the Julia ecosystem.
Highlights include:
- High performance, e.g. 200x speedup over torchdiffeq, with performance similar to DifferentialEquations.jl.
- Numerous features: high-order solvers, implicit solvers, dense solutions, multiple adjoint methods, etc.
- Integrates directly with JAX: jit/grad/vmap/etc. all work as normal.
- Easily extensible with custom ops (solvers etc.); includes the ability to handle the stepping yourself, e.g. when writing your own differentiable simulator.
- For the numerical analyst: ODEs/SDEs/etc. are all solved in a unified way.

JAX Global Meetup
Online event