JAX, from Google Research, is becoming increasingly popular.
DeepMind recently announced that it is using JAX to accelerate its research and has already developed a set of libraries on top of it.
A growing number of research papers are built with JAX, for example the recent Vision Transformer (ViT).
What is JAX?
And what are its strengths?
JAX is a pretty low-level library similar to NumPy but with several cool features:
- Autograd — JAX can automatically differentiate native Python and NumPy code. It can differentiate through a large subset of Python’s features, including loops, ifs, recursion, and closures, and it can even take derivatives of derivatives of derivatives. It supports reverse-mode as well as forward-mode differentiation, and the two can be composed arbitrarily to any order. The good old autograd library is a predecessor of JAX, and some of the autograd authors are now developing JAX. And, BTW, if you want a better understanding of what automatic differentiation is, how it differs from numerical or symbolic differentiation, how forward and reverse mode differ, and so on, then read this good survey on the topic.
- Compiling your code — JAX uses XLA to compile and run your NumPy code on accelerators, like GPUs and TPUs. Compilation happens under the hood by default, with library calls getting just-in-time (JIT) compiled and executed.
- JIT — JAX even lets you just-in-time compile your own Python functions into XLA-optimized kernels using a one-function API. Compilation and automatic differentiation can be composed arbitrarily, so you can express sophisticated algorithms and get maximal performance without having to leave Python.
- Asynchronous dispatch is useful since it allows Python code to “run ahead” of an accelerator device, keeping Python code out of the critical path. Provided the Python code enqueues work on the device faster than it can be executed, and provided that the Python code does not actually need to inspect the output of computation on the host, then a Python program can enqueue arbitrary amounts of work and avoid having the accelerator wait.
- Auto-vectorization with vmap, the vectorizing map. It has the familiar semantics of mapping a function along array axes, but instead of keeping the loop on the outside, it pushes the loop down into a function’s primitive operations for better performance. When composed with JIT, it can be just as fast as adding the batch dimensions by hand.
- Parallelization with pmap. The purpose of pmap() is to express single-program multiple-data (SPMD) programs. Applying pmap() to a function will compile the function with XLA (similarly to jit()), then execute it in parallel on XLA devices, such as multiple GPUs or multiple TPU cores. Semantically it is comparable to vmap() because both transformations map a function over array axes, but where vmap() vectorizes functions by pushing the mapped axis down into primitive operations, pmap() instead replicates the function and executes each replica on its own XLA device in parallel.
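The transformations listed above compose with each other, which is easiest to see in a few lines of code. The following is only a minimal sketch: the `loss` function, the toy arrays, and all variable names are made up for illustration and do not come from any of the libraries or papers mentioned here.

```python
import jax
import jax.numpy as jnp


def loss(w, x, y):
    """Squared error of a linear model on a single example (toy function)."""
    return (jnp.dot(w, x) - y) ** 2


w = jnp.array([1.0, 2.0, 3.0])
x = jnp.array([0.5, -1.0, 2.0])
y = 1.0

# grad: derivative of `loss` with respect to its first argument (w).
grad_loss = jax.grad(loss)
g = grad_loss(w, x, y)

# Derivatives of derivatives: grad can be applied to its own output.
second = jax.grad(jax.grad(jnp.sin))(0.0)

# vmap + jit: per-example gradients over a batch, compiled with XLA.
# in_axes=(None, 0, 0) shares w and maps over the leading axis of xs and ys.
xs = jnp.stack([x, 2 * x, 3 * x])
ys = jnp.array([1.0, 2.0, 3.0])
per_example = jax.jit(jax.vmap(grad_loss, in_axes=(None, 0, 0)))(w, xs, ys)

# Asynchronous dispatch: the call above may return before the device is done.
# block_until_ready() waits for the result, which matters when timing code.
per_example.block_until_ready()

# pmap: the leading axis must match the number of local devices,
# so on a single-device machine this runs one replica.
n = jax.local_device_count()
shards = jnp.arange(n * 3, dtype=jnp.float32).reshape(n, 3)
doubled = jax.pmap(lambda a: 2 * a)(shards)
```

Note that `vmap` here gives one gradient per example without any change to `loss`, which was written for a single example.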
Here is a short intro to JAX (7 min):
And a deeper introduction (23 min):
And an even deeper one (45 min):
GTC 2020: JAX: Accelerating Machine-Learning Research with Composable Function Transformations in…
And the deepest one from one of the authors of JAX (1h 10m):
And a couple of articles on JAX:
- A good practical intro by Simone Scardapane: “JAX, AKA NUMPY ON STEROIDS”.
- “From PyTorch to JAX: towards neural net frameworks that purify stateful code” is a kind of practical comparison of JAX and PyTorch approaches.
The library looks very promising, so keep an eye on it.
Libraries on top of JAX
There are several high-level libraries for neural networks on top of JAX:
- Flax (Google Brain) — a high-level neural network library designed for flexibility (compare to Keras, Sonnet, or Haiku).
- Haiku (DeepMind) — a JAX-based neural network library, aka Sonnet for JAX (if you don’t know what Sonnet is, it is a high-level neural network library on top of TensorFlow, similar to Keras).
- Trax (Google Brain) — an end-to-end library for deep learning that focuses on clear code and speed. Trax includes basic models (like ResNet, LSTM, Transformer) and RL algorithms (like REINFORCE, A2C, PPO). It is also actively used for research and includes new models like the Reformer and new RL algorithms like AWR. Trax has bindings to a large number of deep learning datasets, including Tensor2Tensor and TensorFlow datasets.
- Objax (Google) — a minimalist object-oriented framework with a PyTorch-like interface. Its name comes from the contraction of Object and JAX. Objax is designed by researchers for researchers with a focus on simplicity and understandability. Its users should be able to easily read, understand, extend, and modify it to fit their needs.
- Stax (experimental part of JAX) — a small but flexible neural net specification library from scratch.
- Elegy (PoetsAI) — a neural network framework based on JAX and inspired by Keras. Elegy implements the Keras API but makes changes to play better with JAX, gives more flexibility around losses and metrics, and has a module system that makes it super easy to use.
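To give a feel for how small a network specification can be, here is a rough sketch using Stax from the list above. It assumes a JAX release where Stax lives under `jax.example_libraries` (older releases exposed it as `jax.experimental.stax`); the layer sizes and input shape are arbitrary illustration values.

```python
import jax
import jax.numpy as jnp
from jax.example_libraries import stax  # assumption: older JAX used jax.experimental.stax

# A network is a pair of pure functions: one creates parameters, one applies them.
init_fn, apply_fn = stax.serial(
    stax.Dense(16), stax.Relu,
    stax.Dense(2),
)

rng = jax.random.PRNGKey(0)
# -1 marks the batch dimension; 4 is the (made-up) number of input features.
out_shape, params = init_fn(rng, (-1, 4))

x = jnp.ones((8, 4))          # a batch of 8 dummy inputs
logits = apply_fn(params, x)  # parameters are passed in explicitly
```

The parameters live outside the model and are passed in on every call, which is the functional style that the other libraries wrap in different ways.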
A couple of RL libraries (but keep in mind some high-level libraries, say, Trax, have RL features too):
- RLax (DeepMind) — building blocks for implementing reinforcement learning agents. The components in RLax cover a broad spectrum of algorithms and ideas: TD-learning, policy gradients, actor critics, MAP, proximal policy optimization, non-linear value transformation, general value functions, and a number of exploration methods.
- Coax (Microsoft) — a modular Reinforcement Learning (RL) Python package for solving OpenAI Gym environments with JAX-based function approximations.
Some other more specialized libraries:
- Optax (DeepMind) — a gradient processing and optimization library for JAX. Optax provides a library of gradient transformations, together with composition operators (e.g. chain) that allow implementing many standard optimizers (e.g. RMSProp or Adam) in just a single line of code.
- Chex (DeepMind) — a library of utilities for helping to write reliable JAX code. Chex provides an assortment of utilities including JAX-aware unit testing, assertions of properties of JAX datatypes, mocks and fakes, and multi-device test environments.
- Jraph, pronounced “giraffe”, (DeepMind) — a lightweight library for working with graph neural networks in JAX. Jraph provides a standardized data structure for graphs, a set of utilities for working with graphs, and a ‘zoo’ of easily forkable and extensible graph neural network models.
- JAX, M.D. (Google) — a framework for differentiable physics. It provides differentiable, hardware-accelerated molecular dynamics built on top of JAX.
- Oryx (Google?) — a library for probabilistic programming and deep learning built on top of JAX. Oryx is an experimental library that extends JAX to applications ranging from building and training complex neural networks to approximate Bayesian inference in deep generative models. Just as JAX provides grad, Oryx provides a set of composable function transformations that enable writing simple code and transforming it to build complexity while staying completely interoperable with JAX.