JAX: quickstart guide to high performance computation with python

Jul 31, 2022 • Written by Rene-Jean Corneille


JAX, a library built by Google Research, brings capabilities familiar from TensorFlow to a NumPy-like API:

  • accelerated linear algebra (XLA)
  • hardware accelerators (GPU or TPU)
  • automatic differentiation
  • just-in-time compilation

Prototyping code with TensorFlow can be hard. Since that library has evolved from a general-purpose tool into one focused on neural network applications, JAX turns out to be a great alternative for research.
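The benchmarks below focus on random numbers and linear algebra, so automatic differentiation and just-in-time compilation from the feature list do not appear again. Here is a minimal sketch of both. `sum_of_squares` is a hypothetical example function, not part of the original notebook.

```python
import jax
import jax.numpy as jnp

# A plain Python function written with jax.numpy (hypothetical example)
def sum_of_squares(x):
    return jnp.sum(x ** 2)

# jax.grad returns a new function that computes the gradient, here 2 * x
grad_fn = jax.grad(sum_of_squares)

# jax.jit compiles the function with XLA the first time it is called
fast_sum_of_squares = jax.jit(sum_of_squares)

x = jnp.array([1.0, 2.0, 3.0])
print(fast_sum_of_squares(x))  # 14.0
print(grad_fn(x))              # [2. 4. 6.]
```

`jax.grad` and `jax.jit` compose. `jax.jit(jax.grad(sum_of_squares))`, for example, is a compiled gradient function.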

import jax
import numpy as np
import jax.numpy as jnp

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

JAX code can run on GPU or TPU. GPU acceleration is already available in Python through libraries such as CuPy or Numba. TPU support, however, is a clear advantage of JAX, which powers the research ecosystem at DeepMind.

JAX lets you select the default device. If an accelerator is detected, it is selected as the default device automatically. Still, it is better practice to state the choice explicitly:

jax.config.update('jax_platform_name', 'gpu')
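You may select the platform with the `jax_platform_name` option above or, in recent JAX versions, with the `JAX_PLATFORMS` environment variable. Either way, it is worth checking which devices JAX actually sees before benchmarking. The sketch below uses standard JAX calls (`jax.devices`, `jax.default_backend`, `jax.device_put`, `jax.default_device`). It assumes at least one CPU device is available, which is always the case.

```python
import jax
import jax.numpy as jnp

# All devices JAX can see, and the platform it uses by default
print(jax.devices())
print(jax.default_backend())  # 'cpu', 'gpu' or 'tpu'

# Place an array explicitly on the first CPU device, whatever the default is
cpu = jax.devices("cpu")[0]
x = jax.device_put(jnp.ones((3, 3)), cpu)

# Or make a device the default for a block of code
with jax.default_device(cpu):
    y = jnp.zeros((3, 3))
```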

Generating random numbers: JAX CPU vs NumPy

Here I run a simple benchmark: generating an N×N matrix of independent Gaussian draws, with N = 10000.

jax.config.update('jax_platform_name', 'cpu')
N = 10000
key = jax.random.PRNGKey(29)
%%time
jax_array = jax.random.normal(key, (N, N), dtype=np.float32).block_until_ready()
CPU times: user 2.5 s, sys: 484 ms, total: 2.99 s
Wall time: 1.65 s
%%time
numpy_array = np.random.normal(size=(N, N))
CPU times: user 4.41 s, sys: 42 ms, total: 4.46 s
Wall time: 4.43 s

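In this simple test, JAX on CPU generates the random numbers roughly 2.7x faster than NumPy in wall-clock time (1.65 s vs 4.43 s). The comparison is not strictly like-for-like, though. JAX generates float32 values here while np.random.normal returns float64, and this first JAX call also includes compilation time.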
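A fairer comparison generates the same dtype on both sides and excludes JAX's one-off compilation by warming up first. Below is a minimal sketch under a few assumptions. It uses NumPy's `Generator.standard_normal` with `dtype=np.float32`, a smaller hypothetical N = 2000 to keep it quick, and a hypothetical helper `best_time`.

```python
import time

import jax
import numpy as np

N = 2000  # hypothetical smaller size; the article uses N = 10000

key = jax.random.PRNGKey(29)
rng = np.random.default_rng(29)

def jax_normal():
    # block_until_ready waits for JAX's asynchronous computation to finish
    return jax.random.normal(key, (N, N), dtype=np.float32).block_until_ready()

def numpy_normal():
    # Generator.standard_normal can produce float32 directly, like JAX
    return rng.standard_normal((N, N), dtype=np.float32)

def best_time(fn, repeats=5):
    """Return the fastest wall-clock time in seconds over several calls."""
    fn()  # warm-up call: for JAX this includes compilation
    times = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        times.append(time.perf_counter() - start)
    return min(times)

print(f"JAX:   {best_time(jax_normal):.4f} s")
print(f"NumPy: {best_time(numpy_normal):.4f} s")
```

In a notebook, `%timeit` performs the repeated timing for you. The JAX expression still needs `.block_until_ready()`.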

Generating random numbers: JAX GPU vs NumPy

jax.config.update('jax_platform_name', 'gpu')
%%time
jax_array = jax.random.normal(key, (N, N), dtype=np.float32).block_until_ready()
CPU times: user 2.35 s, sys: 482 ms, total: 2.84 s
Wall time: 1.3 s
%%time
numpy_array = np.random.normal(size=(N, N))
CPU times: user 4.39 s, sys: 54 ms, total: 4.45 s
Wall time: 4.4 s

Generating the random numbers with JAX on GPU provides a meaningful speed-up over NumPy (1.3 s vs 4.4 s of wall time). This motivated the emergence of JAX-powered probabilistic programming libraries:

  • distrax: a DeepMind library with an API similar to TensorFlow Probability's
  • numpyro: a JAX implementation loosely inspired by Pyro

Moreover, PyMC3 and TensorFlow Probability can also run some of their functionality on top of JAX.
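These libraries build on JAX's explicit, stateless random number generation. NumPyro's `MCMC.run`, for instance, takes an `rng_key` as its first argument. The benchmarks above reuse a single `key`, which is fine for timing. In real code, though, the same key always yields the same numbers, so keys are split to get fresh randomness. A minimal sketch:

```python
import jax

key = jax.random.PRNGKey(29)

# The same key always yields the same numbers: JAX's PRNG is stateless
a = jax.random.normal(key, (3,))
b = jax.random.normal(key, (3,))

# For fresh randomness, split the key and use each subkey only once
key, subkey = jax.random.split(key)
c = jax.random.normal(subkey, (3,))

# split can also produce many subkeys at once, e.g. one per chain or batch item
subkeys = jax.random.split(key, 4)
```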

Generating random numbers: JAX TPU vs NumPy

jax.config.update('jax_platform_name', 'tpu')
%%time
jax_array = jax.random.normal(key, (N, N), dtype=np.float32).block_until_ready()
CPU times: user 2.5 s, sys: 10.9 ms, total: 2.51 s
Wall time: 1.05 s
%%time
numpy_array = np.random.normal(size=(N, N))
CPU times: user 4.41 s, sys: 21.9 ms, total: 4.43 s
Wall time: 4.39 s

Dot products: JAX CPU vs NumPy

There are several ways of performing a dot product in JAX. Here jnp.dot and jnp.vdot are compared with NumPy's np.dot.

jax.config.update('jax_platform_name', 'cpu')
%%time
cpu_array = jax.random.normal(key, (N, N), dtype=np.float32).block_until_ready()
CPU times: user 2.48 s, sys: 10.3 ms, total: 2.49 s
Wall time: 1.03 s
numpy_array = np.random.normal(size=(N, N))
%%time
_ = np.dot(numpy_array, numpy_array.T)
CPU times: user 58.8 s, sys: 3.6 s, total: 1min 2s
Wall time: 16.7 s
%%time
_ = jnp.dot(cpu_array, cpu_array.T)
CPU times: user 1min 2s, sys: 37.6 ms, total: 1min 2s
Wall time: 16.7 s

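A square matrix multiplication is a computationally expensive operation: multiplying two N×N matrices takes on the order of N³ multiply-adds. Here there is little difference between NumPy and jnp.dot (16.7 s of wall time each). JAX also provides jnp.vdot, which finishes much faster:

%%time
_ = jnp.vdot(cpu_array, cpu_array.T)
CPU times: user 3.25 s, sys: 24 ms, total: 3.28 s
Wall time: 1.79 s

However, jnp.vdot does not compute a matrix product. Like np.vdot, it flattens both inputs into 1D vectors and returns a single scalar: the sum of their elementwise products, conjugating the first argument for complex inputs. With cpu_array.T as the second argument, that scalar is the trace of cpu_array @ cpu_array. Computing it takes only on the order of N² operations, so this timing is not evidence of a faster matrix multiplication.

On CPU, JAX's matrix multiplication therefore runs at about the same speed as NumPy's in this test. XLA (Accelerated Linear Algebra), the compiler behind JAX, mainly helps when it can compile and fuse sequences of operations (for example with jax.jit), and a single matrix multiplication leaves little room for that.

Despite the similar name, jnp.vdot is unrelated to jax.vmap. jax.vmap is JAX's automatic vectorization transform: it maps a function over a batch axis of its inputs.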
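jnp.dot, jnp.vdot and jax.vmap compute quite different things, and a minimal sketch on small arrays makes the difference concrete. The shapes (4×4, and 8 pairs of length-5 vectors) are arbitrary choices for illustration.

```python
import jax
import jax.numpy as jnp

A = jax.random.normal(jax.random.PRNGKey(0), (4, 4))

# jnp.dot on 2D arrays is a matrix product: the result is another 4x4 matrix
matmul = jnp.dot(A, A.T)

# jnp.vdot flattens both inputs and returns one scalar:
# sum_ij A[i, j] * A.T[i, j] = sum_ij A[i, j] * A[j, i] = trace(A @ A)
scalar = jnp.vdot(A, A.T)

# jax.vmap maps a function over a leading batch axis: here, 8 independent
# dot products between pairs of length-5 vectors
xs = jax.random.normal(jax.random.PRNGKey(1), (8, 5))
ys = jax.random.normal(jax.random.PRNGKey(2), (8, 5))
batched_dots = jax.vmap(jnp.dot)(xs, ys)

print(matmul.shape, scalar.shape, batched_dots.shape)  # (4, 4) () (8,)
```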

Dot products: JAX GPU vs NumPy

jax.config.update('jax_platform_name', 'gpu')
%%time
gpu_array = jax.random.normal(key, (N, N), dtype=np.float32).block_until_ready()
CPU times: user 2.91 s, sys: 14.8 ms, total: 2.92 s
Wall time: 1.34 s
numpy_array = np.random.normal(size=(N, N))
%%time
_ = np.dot(numpy_array, numpy_array.T)
CPU times: user 1min 3s, sys: 562 ms, total: 1min 4s
Wall time: 18.7 s

When a jax.numpy dot product is called on a NumPy array, the array is first converted to a JAX array, which means it is copied to the default device (in this case, a GPU). Communication costs are therefore an important factor in overall performance. Profiling the execution of a jax.numpy operation gives a clearer view of how the program uses its resources.
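One way to keep communication costs down is to copy data to the device once with `jax.device_put` and reuse the resulting array. The alternative, passing the NumPy array to every call, converts it again each time. JAX also ships a `jax.profiler` module for capturing execution traces. The sketch below only illustrates the reuse pattern, and the 1000×1000 size is a hypothetical choice.

```python
import numpy as np
import jax
import jax.numpy as jnp

host_array = np.random.default_rng(0).standard_normal((1000, 1000), dtype=np.float32)

# Passing the NumPy array directly: it is converted (and copied to the
# default device) on every call
for _ in range(3):
    _ = jnp.dot(host_array, host_array.T).block_until_ready()

# Copy once with jax.device_put, then reuse the device-resident array
device_array = jax.device_put(host_array)
for _ in range(3):
    result = jnp.dot(device_array, device_array.T).block_until_ready()

# Bring the final result back to host memory as a NumPy array
host_result = np.asarray(result)
```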

%%time
_ = jnp.dot(gpu_array, gpu_array.T).block_until_ready()
CPU times: user 1min 1s, sys: 800 ms, total: 1min 2s
Wall time: 19.1 s
%%time
_ = jnp.vdot(gpu_array, gpu_array.T).block_until_ready()
CPU times: user 1.63 s, sys: 40 ms, total: 1.67 s
Wall time: 906 ms

Dot products: JAX TPU vs NumPy

jax.config.update('jax_platform_name', 'tpu')
%%time
tpu_array = jax.random.normal(key, (N, N), dtype=np.float32)
CPU times: user 2.3 s, sys: 20 ms, total: 2.32 s
Wall time: 1.22 s
numpy_array = np.random.normal(size=(N, N))
%%time
_ = np.dot(numpy_array, numpy_array.T)
CPU times: user 1min 4s, sys: 62.8 ms, total: 1min 4s
Wall time: 18.7 s
%%time
_ = jnp.dot(tpu_array, tpu_array.T).block_until_ready()
CPU times: user 1min 3s, sys: 765 ms, total: 1min 4s
Wall time: 17.8 s
%%time
_ = jnp.vdot(tpu_array, tpu_array.T).block_until_ready()
CPU times: user 2.78 s, sys: 24 ms, total: 2.81 s
Wall time: 1.52 s

Copying data to a TPU is slow, so using a TPU only pays off when the gain in execution time outweighs the communication overhead. Otherwise, a GPU is the better choice.
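Timing the host-to-device transfer separately from the computation shows whether that trade-off holds. It also helps interpret the results above. The GPU and TPU jnp.dot wall times (19.1 s and 17.8 s, each with about a minute of CPU user time) are close to the CPU one. That may mean those runs were still executing on the CPU, and printing `jax.default_backend()` alongside the timings would confirm it. Below is a minimal sketch with a hypothetical N = 2000 and a jitted matrix product.

```python
import time

import numpy as np
import jax

N = 2000  # hypothetical size for a quick check; the article uses N = 10000

host_array = np.random.default_rng(0).standard_normal((N, N), dtype=np.float32)
matmul = jax.jit(lambda a: a @ a.T)

# 1) Host-to-device transfer only
start = time.perf_counter()
device_array = jax.device_put(host_array).block_until_ready()
transfer_s = time.perf_counter() - start

# 2) Compute only: one warm-up call so that compilation is not timed
matmul(device_array).block_until_ready()
start = time.perf_counter()
result = matmul(device_array).block_until_ready()
compute_s = time.perf_counter() - start

print(f"backend: {jax.default_backend()}")
print(f"transfer: {transfer_s:.4f} s, compute: {compute_s:.4f} s")
```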

Results

# Wall-clock times in milliseconds, taken from the %%time outputs above.
# vdot returns a scalar, not a matrix product (see the CPU section).
results = pd.DataFrame.from_dict(
    {
        'random_numbers': [4400.0, 1650.0, 1300.0, 1050.0],
        'matmul (dot)': [18700.0, 16700.0, 19100.0, 17800.0],
        'vdot (scalar)': [np.nan, 1790.0, 906.0, 1520.0],
    },
    orient='index',
    columns=['numpy', 'jax_cpu', 'jax_gpu', 'jax_tpu']
)
fig, ax = plt.subplots(figsize=(20,10))
cmap = sns.blend_palette(['#319bc5', '#f75782'], n_colors=20, as_cmap=True)
# Convert milliseconds to seconds; NaN cells are left blank by seaborn
sns.heatmap(results/1000, annot=True, fmt=".2f", 
           linewidths=5, cmap=cmap, 
           cbar_kws={"shrink": .8}, square=True, ax=ax)
plt.title('Task wall-clock time in seconds')
plt.show()

Where to go from here?

JAX is an exciting addition to the Python library landscape. Historically, if you wanted to speed up critical parts of your Python programs, the popular solutions were:

  • writing C/C++ bindings to run the expensive parts of your code in a lower-level language
  • using GPU-enabled libraries such as CuPy or RAPIDS

JAX is an extra option that feels more natural, since the same code can run with or without accelerators with minimal changes. The main issue with the library is that jax.numpy does not always reproduce NumPy's behaviour exactly. However, the developers are clearly striving towards an API as close to NumPy as possible. The library is still experimental, so we should hopefully see some improvement on this front.
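Two of the most visible differences from NumPy are default precision and mutability. A minimal sketch follows, assuming JAX's default settings (64-bit mode not enabled):

```python
import numpy as np
import jax.numpy as jnp

np_x = np.zeros(3)
jax_x = jnp.zeros(3)

# Default precision: NumPy uses float64, JAX uses float32 unless 64-bit mode
# is enabled with jax.config.update("jax_enable_x64", True)
print(np_x.dtype, jax_x.dtype)  # float64 float32

# NumPy arrays can be modified in place
np_x[0] = 1.0

# JAX arrays are immutable: `jax_x[0] = 1.0` raises a TypeError.
# The functional equivalent returns an updated copy instead.
jax_x = jax_x.at[0].set(1.0)
```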

