GitHub - patrick-kidger/diffrax: Numerical differential equation solvers in JAX....
source link: https://github.com/patrick-kidger/diffrax
Diffrax
Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable.
Diffrax is a JAX-based library providing numerical differential equation solvers.
Features include:
- ODE/SDE/CDE (ordinary/stochastic/controlled) solvers;
- lots of different solvers (including Tsit5, Dopri8, symplectic solvers, implicit solvers);
- vmappable everything (including the region of integration);
- using a PyTree as the state;
- dense solutions;
- multiple adjoint methods for backpropagation;
- support for neural differential equations.
From a technical point of view, the internal structure of the library is pretty cool -- all kinds of equations (ODEs, SDEs, CDEs) are solved in a unified way (rather than being treated separately), producing a small tightly-written library.
Installation
pip install diffrax
Requires Python >=3.7 and JAX >=0.2.27.
Documentation
Available at https://docs.kidger.site/diffrax.
Quick example
```python
from diffrax import diffeqsolve, ODETerm, Dopri5
import jax.numpy as jnp

# Vector field of the ODE dy/dt = -y.
def f(t, y, args):
    return -y

term = ODETerm(f)
solver = Dopri5()
y0 = jnp.array([2., 3.])
# Integrate from t0=0 to t1=1 using step size dt0=0.1.
solution = diffeqsolve(term, solver, t0=0, t1=1, dt0=0.1, y0=y0)
```
Here, Dopri5 refers to the Dormand–Prince 5(4) numerical differential equation solver, which is a standard choice for many problems.
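Here is what "5(4)" means concretely. Each Dormand–Prince step evaluates the vector field at seven stages. It combines those stages in two ways: a 5th-order solution, and an embedded 4th-order one whose difference from it serves as a local error estimate for adaptive step-size control.

Below is a minimal pure-Python sketch for a scalar ODE, using the standard published Butcher tableau. The name `dopri5_step` is hypothetical. This is not how diffrax implements Dopri5 internally: diffrax works on PyTrees, compiles with JAX, and pairs each step with a step-size controller.

```python
import math

# Dormand-Prince 5(4) Butcher tableau: nodes C, stage coefficients A,
# 5th-order weights B5 and embedded 4th-order weights B4.
C = [0.0, 1/5, 3/10, 4/5, 8/9, 1.0, 1.0]
A = [
    [],
    [1/5],
    [3/40, 9/40],
    [44/45, -56/15, 32/9],
    [19372/6561, -25360/2187, 64448/6561, -212/729],
    [9017/3168, -355/33, 46732/5247, 49/176, -5103/18656],
    [35/384, 0.0, 500/1113, 125/192, -2187/6784, 11/84],
]
B5 = [35/384, 0.0, 500/1113, 125/192, -2187/6784, 11/84, 0.0]
B4 = [5179/57600, 0.0, 7571/16695, 393/640, -92097/339200, 187/2100, 1/40]


def dopri5_step(f, t, y, h):
    """One Dormand-Prince step for a scalar ODE y' = f(t, y).

    Returns the 5th-order solution and the magnitude of its difference from
    the embedded 4th-order solution (a local error estimate)."""
    k = []
    for i in range(7):
        # Stage value built from the previously computed stage derivatives.
        yi = y + h * sum(a * kj for a, kj in zip(A[i], k))
        k.append(f(t + C[i] * h, yi))
    y5 = y + h * sum(b * kj for b, kj in zip(B5, k))
    y4 = y + h * sum(b * kj for b, kj in zip(B4, k))
    return y5, abs(y5 - y4)


# One step of size 0.1 for y' = -y from y(0) = 1; the exact value is exp(-0.1).
y1, err = dopri5_step(lambda t, y: -y, 0.0, 1.0, 0.1)
exact = math.exp(-0.1)
```

The last row of A equals the 5th-order weights, so the seventh stage is evaluated at the new solution itself. A practical implementation reuses it as the first stage of the next step (the "first same as last" property), so each accepted step costs six new vector-field evaluations.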
Citation
If you found this library useful in academic research, please cite: (arXiv link)
```bibtex
@phdthesis{kidger2021on,
    title={{O}n {N}eural {D}ifferential {E}quations},
    author={Patrick Kidger},
    year={2021},
    school={University of Oxford},
}
```
(Also consider starring the project on GitHub.)