JAX Ecosystem

Grigory Sapunov
Dec 20, 2020


JAX by Google Research is getting more and more popular.

UPDATE from 2023: I have almost finished writing a book, “Deep Learning with JAX”, with Manning. It is already available as an early release (MEAP) and contains a lot of useful, up-to-date material on JAX.

DeepMind recently announced that they are using JAX to accelerate their research, and they have already developed a set of libraries on top of JAX.

There are more and more research papers built using JAX, say, the recent Vision Transformer (ViT).

JAX demonstrates impressive performance. And even the world’s fastest transformer is built on JAX now.

A recent JAX success story is the GPT-J-6B model by EleutherAI. According to the authors: “This project required a substantially smaller amount of person-hours than other large-scale model developments did, which demonstrates that JAX + xmap + TPUs is the right set of tools for quick development of large-scale models.”

What is JAX?

And what are its strengths?

JAX is a pretty low-level library similar to NumPy but with several cool features:

  • Autograd — JAX can automatically differentiate native Python and NumPy code. It can differentiate through a large subset of Python’s features, including loops, ifs, recursion, and closures, and it can even take derivatives of derivatives of derivatives. It supports reverse-mode as well as forward-mode differentiation, and the two can be composed arbitrarily to any order. The good old autograd library is a predecessor of JAX; some of the autograd authors are now developing JAX. And, BTW, if you seek a better understanding of what automatic differentiation is, how it differs from numerical or symbolic differentiation, what the difference is between forward and reverse mode, and so on, then read this good survey on the topic.
  • Compiling your code — JAX uses XLA to compile and run your NumPy code on accelerators like GPUs and TPUs. Compilation happens under the hood by default, with library calls getting just-in-time (JIT) compiled and executed.
  • JIT — JAX even lets you just-in-time compile your own Python functions into XLA-optimized kernels using a one-function API. Compilation and automatic differentiation can be composed arbitrarily, so you can express sophisticated algorithms and get maximal performance without having to leave Python.
  • Asynchronous dispatch lets Python code “run ahead” of an accelerator device and keeps Python out of the critical path, because JAX returns control to Python before a device computation has finished. As long as the Python code enqueues work on the device faster than the device can execute it, and does not need to inspect the results on the host, a Python program can enqueue arbitrary amounts of work and the accelerator never has to wait.
  • Auto-vectorization with vmap, the vectorizing map. It has the familiar semantics of mapping a function along array axes, but instead of keeping the loop on the outside, it pushes the loop down into a function’s primitive operations for better performance. When composed with JIT, it can be just as fast as adding the batch dimensions by hand.
  • Parallelization with pmap. The purpose of pmap() is to express single-program multiple-data (SPMD) programs. Applying pmap() to a function will compile the function with XLA (similarly to jit()), then execute it in parallel on XLA devices, such as multiple GPUs or multiple TPU cores. Semantically, it is comparable to vmap() because both transformations map a function over array axes, but where vmap() vectorizes functions by pushing the mapped axis down into primitive operations, pmap() instead replicates the function and executes each replica on its own XLA device in parallel.
  • Named-axis programming model with xmap (experimental). This helps you write error-avoiding, self-documenting functions using named axes, then control how they’re executed on hardware at any scale, from your laptop CPU to the largest TPU supercomputer. The killer feature of xmap is its ability to parallelize code over supercomputer-scale hardware meshes!
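
The sketch below shows how these transformations compose. It assumes a recent JAX release that provides `jax.grad`, `jax.jacfwd`, `jax.jit`, `jax.vmap`, `jax.pmap`, and `Array.block_until_ready()`. The function `f` is a hypothetical toy example. `xmap` is left out because it was experimental and its API has changed between releases.

```python
import jax
import jax.numpy as jnp

# A hypothetical toy function written in plain NumPy style: f(x) = x^3 + 2x
def f(x):
    return x ** 3 + 2.0 * x

# Autograd: grad returns a new function, and nesting gives higher derivatives.
df = jax.grad(f)      # f'(x)   = 3x^2 + 2
d2f = jax.grad(df)    # f''(x)  = 6x
d3f = jax.grad(d2f)   # f'''(x) = 6

# Forward mode composed with reverse mode (forward-over-reverse second derivative).
hess = jax.jacfwd(jax.grad(f))

# JIT: compile the derivative with XLA; the first call traces and compiles it.
fast_df = jax.jit(df)

# vmap: evaluate the scalar derivative over a whole batch without a Python loop.
xs = jnp.arange(4.0)
batched = jax.vmap(fast_df)(xs)

# Asynchronous dispatch: the call above may return before the device is done;
# block_until_ready() waits for the result (important when timing code).
batched = batched.block_until_ready()

# pmap: replicate the function across local devices (just one on a plain CPU).
n = jax.local_device_count()
per_device = jax.pmap(fast_df)(jnp.arange(float(n)))
```

Each transformation takes a function and returns a function. That is why they nest freely, as in `jax.vmap(jax.jit(jax.grad(f)))`.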

Here is a short intro to JAX (7 min):

And a deeper introduction (23 min):

And an even deeper one (45 min):

And the deepest one from one of the authors of JAX (1h 10m):

And a couple of articles on JAX:

The library looks very promising, so keep an eye on it.

Libraries on top of JAX

There are several high-level libraries for neural networks on top of JAX:

  • Flax (Google Brain) — a high-level neural network library designed for flexibility (compare to Keras, Sonnet, or Haiku).
    Repo: https://github.com/google/flax
  • Haiku (DeepMind) — a JAX-based neural network library, aka Sonnet for JAX (if you don’t know what Sonnet is, it is a high-level neural network library on top of TensorFlow, similar to Keras).
    Repo: https://github.com/deepmind/dm-haiku
  • Trax (Google Brain) — an end-to-end library for deep learning that focuses on clear code and speed. Trax includes basic models (like ResNet, LSTM, and Transformer) and RL algorithms (like REINFORCE, A2C, and PPO). It is also actively used for research and includes new models like the Reformer and new RL algorithms like AWR. Trax has bindings to a large number of deep learning datasets, including Tensor2Tensor and TensorFlow datasets.
    Repo: https://github.com/google/trax
  • Objax (Google) — a minimalist object-oriented framework with a PyTorch-like interface. Its name comes from the contraction of Object and JAX. Objax is designed by researchers for researchers with a focus on simplicity and understandability. Its users should be able to easily read, understand, extend, and modify it to fit their needs.
    Repo: https://github.com/google/objax
  • Stax (experimental part of JAX) — a small but flexible neural net specification library from scratch.
    Repo: https://jax.readthedocs.io/en/latest/jax.experimental.stax.html
  • Elegy (PoetsAI) — a neural network framework based on JAX and inspired by Keras. Elegy implements the Keras API but makes changes to play better with JAX. It also offers more flexibility around losses and metrics, along with a module system that makes it super easy to use.
    Repo: https://github.com/poets-ai/elegy
  • Mesh Transformer JAX — the framework that was used to build the recent GPT-J-6B language model. It provides model-parallel transformers in JAX and Haiku. It is designed for scalability up to approximately 20B parameters on TPUv3s, beyond which different parallelism strategies should be used.
    Repo: https://github.com/kingoflolz/mesh-transformer-jax
  • Swarm JAX — Swarm training framework using Haiku + JAX + Ray for layer parallel transformer language models on unreliable, heterogeneous nodes.
    Repo: https://github.com/kingoflolz/swarm-jax

A couple of RL libraries (but keep in mind some high-level libraries, say, Trax, have RL features too):

  • RLax (DeepMind) — building blocks for implementing reinforcement learning agents. The components in RLax cover a broad spectrum of algorithms and ideas: TD-learning, policy gradients, actor critics, MAP, proximal policy optimization, non-linear value transformation, general value functions, and a number of exploration methods.
    Repo: https://github.com/deepmind/rlax
  • Coax (Microsoft) — a modular reinforcement learning (RL) Python package for solving OpenAI Gym environments with JAX-based function approximators.
    Repo: https://github.com/microsoft/coax
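
RL building blocks of this kind are usually small pure functions. As an illustration, here is a one-step temporal-difference (TD) error written by hand in plain JAX. It does not use RLax. RLax ships its own ready-made version (for example `rlax.td_learning`), so check the RLax docs for its exact signature. The function name `td_error` and the sample numbers are hypothetical.

```python
import jax
import jax.numpy as jnp

def td_error(v_tm1, r_t, discount_t, v_t):
    """Hypothetical one-step TD error: r_t + discount_t * v_t - v_tm1.

    stop_gradient on the bootstrap value means gradients flow only
    through the prediction v_tm1 (semi-gradient TD).
    """
    target = r_t + discount_t * jax.lax.stop_gradient(v_t)
    return target - v_tm1

# vmap over a batch of transitions (the arithmetic is elementwise anyway;
# vmap is used here to show how per-transition code is batched).
batched_td = jax.vmap(td_error)

v_tm1 = jnp.array([1.0, 0.5, 0.0])
r_t = jnp.array([0.0, 1.0, 1.0])
discount_t = jnp.array([0.9, 0.9, 0.0])   # 0.0 marks a terminal transition
v_t = jnp.array([0.5, 0.0, 2.0])
errors = batched_td(v_tm1, r_t, discount_t, v_t)
```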

Some other more specialized libraries:

  • Optax (DeepMind) — a gradient processing and optimization library for JAX. Optax provides a library of gradient transformations, together with composition operators (e.g. chain) that allow implementing many standard optimizers (e.g. RMSProp or Adam) in just a single line of code.
    Repo: https://github.com/deepmind/optax
  • Chex (DeepMind) — a library of utilities for helping to write reliable JAX code. Chex provides an assortment of utilities including JAX-aware unit testing, assertions of properties of JAX datatypes, mocks and fakes, and multi-device test environments.
    Repo: https://github.com/deepmind/chex
  • Jraph, pronounced “giraffe”, (DeepMind) — a lightweight library for working with graph neural networks in JAX. Jraph provides a standardized data structure for graphs, a set of utilities for working with graphs, and a ‘zoo’ of easily forkable and extensible graph neural network models.
    Repo: https://github.com/deepmind/jraph
  • JAX, M.D. (Google) — a framework for differentiable physics. It provides differentiable, hardware-accelerated molecular dynamics built on top of JAX.
    Repo: https://github.com/google/jax-md
  • Oryx (Google?) is a library for probabilistic programming and deep learning built on top of JAX. Oryx is an experimental library that extends JAX to applications ranging from building and training complex neural networks to approximate Bayesian inference in deep generative models. Just as JAX provides jit, vmap, and grad, Oryx provides a set of composable function transformations that let you write simple code and transform it to build complexity while staying completely interoperable with JAX.
    Repo: https://github.com/tensorflow/probability/tree/master/spinoffs/oryx

Release Notes

2023/10/05 Added a note about the “Deep Learning with JAX” book.

2021/06/17 Added Mesh Transformer JAX and Swarm JAX. Added Named-axis programming model with xmap.

2020/12/20 Original article published

Grigory Sapunov

ML/DL/AI expert. Software engineer with 20+ years programming experience. Loves Life Sciences. CTO and co-Founder of Intento. Google Developer Expert in ML.