JAX brings Autograd and XLA together for high-performance machine learning research. It can automatically differentiate native Python and NumPy functions, including code that uses loops, branches, recursion, and closures, and it can take derivatives of derivatives of derivatives. JAX supports reverse-mode differentiation (a.k.a. backpropagation) via grad as well as forward-mode differentiation, and the two can be composed arbitrarily to any order.
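The capabilities listed above can be illustrated with a minimal sketch, assuming a standard JAX installation (jax and jax.numpy). It uses jax.grad for reverse mode and jax.jacfwd for forward mode. The helper functions tanh, piecewise, power, make_scaled and cube_sum are hypothetical examples written for illustration; they are not part of the JAX API.

```python
import jax
import jax.numpy as jnp

def tanh(x):
    # Hypothetical helper: tanh written with plain NumPy-style operations.
    y = jnp.exp(-2.0 * x)
    return (1.0 - y) / (1.0 + y)

def piecewise(x):
    # Hypothetical helper with a Python branch and a Python loop.
    # Without jit, grad evaluates these on concrete values, so ordinary
    # control flow works.
    if x < 3.0:
        acc = 0.0
        for k in range(1, 4):
            acc = acc + x ** k  # x + x^2 + x^3
        return acc
    return -4.0 * x

def power(x, n):
    # Hypothetical helper using recursion: x**n as x * x**(n-1).
    return 1.0 if n == 0 else x * power(x, n - 1)

def make_scaled(a):
    # Hypothetical helper returning a closure that captures `a`.
    return lambda x: a * tanh(x)

def cube_sum(v):
    # Hypothetical scalar function of a vector, used for the Hessian.
    return jnp.sum(v ** 3)

# Reverse mode through branches and loops.
d_piece_low = jax.grad(piecewise)(2.0)   # 1 + 2x + 3x^2 at x = 2
d_piece_high = jax.grad(piecewise)(4.0)  # derivative of -4x

# Reverse mode through recursion: d/dx x^3 at x = 2.
d_power = jax.grad(lambda x: power(x, 3))(2.0)

# Reverse mode through a closure: d/dx 2*tanh(x) at x = 0.
d_closure = jax.grad(make_scaled(2.0))(0.0)

# Derivatives of derivatives of derivatives: third derivative of tanh.
d3_tanh = jax.grad(jax.grad(jax.grad(tanh)))

# Forward mode composed with reverse mode: Hessian of cube_sum.
hessian = jax.jacfwd(jax.grad(cube_sum))
H = hessian(jnp.array([1.0, 2.0, 3.0]))  # diagonal matrix diag(6 * v)

print(d_piece_low, d_piece_high, d_power, d_closure)
print(d3_tanh(1.0))
print(H)
```

Composing jax.jacfwd over jax.grad (forward-over-reverse) is one common way to build a Hessian. The reverse pass reduces the scalar output to a gradient, and the forward pass differentiates that gradient once per input dimension.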
For citation information, see https://github.com/google/jax#citing-jax