Jul 31, 2022 • Written by Rene-Jean Corneille
The JAX library, built by Google Research, ports critical functionalities of TensorFlow, such as automatic differentiation, JIT compilation via XLA, and hardware acceleration, to a NumPy-like API. Prototyping code with TensorFlow can be hard: the library has evolved from being general-purpose to being focused on neural-network applications, which makes JAX a great alternative for research.
import jax
import numpy as np
import jax.numpy as jnp
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
JAX code can run on GPU or TPU. GPU access is already possible with CuPy or Numba; making the TPU accessible, however, is a clear advantage of JAX, which powers the research ecosystem at DeepMind.
JAX lets you select the default device (if an accelerator is detected, it becomes the default device), but it is better practice to state the choice explicitly.
jax.config.update('jax_platform_name', 'gpu')
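As a sanity check, JAX can report which backend and devices it actually sees (a minimal sketch; the device names will depend on your runtime):
import jax

# Backend selected after the config update: 'cpu', 'gpu' or 'tpu'
print(jax.default_backend())
# All devices visible to JAX on this host, e.g. [CpuDevice(id=0)] or a list of TPU cores
print(jax.devices())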
Here I perform a simple benchmark: generating a 2D matrix of independent Gaussian draws of size N × N, with N = 10,000.
jax.config.update('jax_platform_name', 'cpu')
N = 10000
key = jax.random.PRNGKey(29)
%%time
jax_array = jax.random.normal(key, (N, N), dtype=np.float32).block_until_ready()
CPU times: user 2.5 s, sys: 484 ms, total: 2.99 s Wall time: 1.65 s
%%time
numpy_array = np.random.normal(size=(N, N))
CPU times: user 4.41 s, sys: 42 ms, total: 4.46 s Wall time: 4.43 s
On this very simple test, JAX random number generation proves to be more than twice as fast as NumPy. Note the call to block_until_ready(): JAX dispatches work asynchronously, so without it we would only be timing the dispatch, not the computation itself.
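JAX also handles randomness differently from NumPy: there is no global seed, and every call takes an explicit PRNG key. A minimal sketch of how keys are split to obtain independent, reproducible draws (the variable names are illustrative):
key = jax.random.PRNGKey(29)

# Splitting produces fresh, independent keys; reusing a key reproduces the same draws
key, subkey = jax.random.split(key)
draws_a = jax.random.normal(subkey, (3,))
draws_b = jax.random.normal(subkey, (3,))  # identical to draws_a
key, subkey = jax.random.split(key)
draws_c = jax.random.normal(subkey, (3,))  # new, independent draws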
jax.config.update('jax_platform_name', 'gpu')
%%time
jax_array = jax.random.normal(key, (N, N), dtype=np.float32).block_until_ready()
CPU times: user 2.35 s, sys: 482 ms, total: 2.84 s Wall time: 1.3 s
%%time
numpy_array = np.random.normal(size=(N, N))
CPU times: user 4.39 s, sys: 54 ms, total: 4.45 s Wall time: 4.4 s
Generating the random numbers with JAX on a GPU provides a meaningful speed-up. This has motivated the emergence of JAX-powered probabilistic programming libraries such as NumPyro and BlackJAX.
Moreover, pymc3 and tensorflow probability also allow some functionalities to run on top of JAX.
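For instance, TensorFlow Probability ships a JAX substrate; a minimal sketch, assuming tensorflow_probability is installed:
import jax
from tensorflow_probability.substrates import jax as tfp

# Sample from a standard normal using TFP's JAX backend;
# the sampler takes an explicit JAX PRNG key as its seed
dist = tfp.distributions.Normal(loc=0., scale=1.)
samples = dist.sample(sample_shape=(5,), seed=jax.random.PRNGKey(0))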
jax.config.update('jax_platform_name', 'tpu')
%%time
jax_array = jax.random.normal(key, (N, N), dtype=np.float32).block_until_ready()
CPU times: user 2.5 s, sys: 10.9 ms, total: 2.51 s Wall time: 1.05 s
%%time
numpy_array = np.random.normal(size=(N, N))
CPU times: user 4.41 s, sys: 21.9 ms, total: 4.43 s Wall time: 4.39 s
There are several possible ways of performing a dot product in JAX.
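For the square matrices used below, these spellings are all equivalent (a minimal sketch; a is an illustrative array):
import jax
import jax.numpy as jnp

a = jax.random.normal(jax.random.PRNGKey(0), (4, 4))

# Equivalent ways of writing the same matrix product
p1 = jnp.dot(a, a.T)
p2 = jnp.matmul(a, a.T)
p3 = a @ a.T
p4 = jnp.einsum('ij,kj->ik', a, a)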
jax.config.update('jax_platform_name', 'cpu')
%%time
cpu_array = jax.random.normal(key, (N, N), dtype=np.float32).block_until_ready()
CPU times: user 2.48 s, sys: 10.3 ms, total: 2.49 s Wall time: 1.03 s
numpy_array = np.random.normal(size=(N, N))
%%time
_ = np.dot(numpy_array, numpy_array.T)
CPU times: user 58.8 s, sys: 3.6 s, total: 1min 2s Wall time: 16.7 s
%%time
_ = jnp.dot(cpu_array, cpu_array.T)
CPU times: user 1min 2s, sys: 37.6 ms, total: 1min 2s Wall time: 16.7 s
A square matrix multiplication is a computationally expensive operation, and here there is little difference between NumPy and a plain JAX matrix multiplication. This is another way in which JAX is a powerful library: on top of its XLA backend it exposes vectorized routines and transforms that can run expensive (but parallelizable) operations even faster:
%%time
_ = jnp.vdot(cpu_array, cpu_array.T)
CPU times: user 3.25 s, sys: 24 ms, total: 3.28 s Wall time: 1.79 s
Even on CPU, with XLA (Accelerated Linear Algebra), JAX can provide a faster runtime on expensive linear algebra operations.
Note, however, that jnp.vdot flattens its inputs and computes a scalar inner product, which is a much cheaper operation than the full matrix product; to vectorize an operation over a batch axis, JAX provides the vmap transform, which splits the input along that axis.
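To make this concrete, here is a minimal sketch of the two transforms usually combined for this kind of work: jax.vmap, which maps a function over a batch axis, and jax.jit, which compiles it with XLA (the function and shapes are illustrative):
import jax
import jax.numpy as jnp

def row_dot(x, y):
    # Inner product of two vectors
    return jnp.dot(x, y)

# vmap maps row_dot over the leading axis of both arguments,
# producing one inner product per row; jit compiles the result with XLA
fast_batched_dot = jax.jit(jax.vmap(row_dot, in_axes=(0, 0)))

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (1000, 100))
out = fast_batched_dot(x, x).block_until_ready()  # shape (1000,)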
jax.config.update('jax_platform_name', 'gpu')
%%time
gpu_array = jax.random.normal(key, (N, N), dtype=np.float32).block_until_ready()
CPU times: user 2.91 s, sys: 14.8 ms, total: 2.92 s Wall time: 1.34 s
numpy_array = np.random.normal(size=(N, N))
%%time
_ = np.dot(numpy_array, numpy_array.T)
CPU times: user 1min 3s, sys: 562 ms, total: 1min 4s Wall time: 18.7 s
When using a jax.numpy dot product on a NumPy array, the array is first converted to a jax.numpy device array (and thus copied to the default device, in this case a GPU), so communication costs become an important factor in the overall performance. It is possible to profile the execution of a jax.numpy operation to get a better view of the program's resource use.
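A minimal sketch of capturing such a profile with JAX's built-in profiler (the trace directory is illustrative; the resulting trace can be inspected in TensorBoard):
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (1000, 1000))

# Record device activity while the operation runs
with jax.profiler.trace('/tmp/jax-trace'):
    jnp.dot(x, x.T).block_until_ready()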
%%time
_ = jnp.dot(gpu_array, gpu_array.T).block_until_ready()
CPU times: user 1min 1s, sys: 800 ms, total: 1min 2s Wall time: 19.1 s
%%time
_ = jnp.vdot(gpu_array, gpu_array.T).block_until_ready()
CPU times: user 1.63 s, sys: 40 ms, total: 1.67 s Wall time: 906 ms
jax.config.update('jax_platform_name', 'tpu')
%%time
tpu_array = jax.random.normal(key, (N, N), dtype=np.float32)
CPU times: user 2.3 s, sys: 20 ms, total: 2.32 s Wall time: 1.22 s
numpy_array = np.random.normal(size=(N, N))
%%time
_ = np.dot(numpy_array, numpy_array.T)
CPU times: user 1min 4s, sys: 62.8 ms, total: 1min 4s Wall time: 18.7 s
%%time
_ = jnp.dot(tpu_array, tpu_array.T).block_until_ready()
CPU times: user 1min 3s, sys: 765 ms, total: 1min 4s Wall time: 17.8 s
%%time
_ = jnp.vdot(tpu_array, tpu_array.T).block_until_ready()
CPU times: user 2.78 s, sys: 24 ms, total: 2.81 s Wall time: 1.52 s
Copying data to a TPU is slow. Hence, when using a TPU, the gain in execution time should offset the communication overhead; otherwise, a GPU is the better choice.
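One way to keep the transfer out of the timed computation is to place the data on the device explicitly beforehand with jax.device_put (a minimal sketch):
import jax
import jax.numpy as jnp
import numpy as np

host_array = np.random.normal(size=(1000, 1000)).astype(np.float32)

# Pay the host-to-device copy once, up front
device_array = jax.device_put(host_array, jax.devices()[0])

# Subsequent operations run on data that is already resident on the accelerator
result = jnp.dot(device_array, device_array.T).block_until_ready()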
# Timings in milliseconds, collected from the cells above
results = pd.DataFrame.from_dict(
{
'random_numbers': [4400.0, 1650.0, 1300.0, 2510.0, np.nan, np.nan, np.nan],
'dot_product': [64000.0, 64000.0, 62000.0, 64000.0, 3280.0, 1670.0, 2810.0]
},
orient='index',
columns=['numpy', 'jax_cpu', 'jax_gpu', 'jax_tpu', 'jax_cpu_parallel', 'jax_gpu_parallel', 'jax_tpu_parallel']
)
fig, ax = plt.subplots(figsize=(20,10))
cmap = sns.blend_palette(['#319bc5', '#f75782'], n_colors=20, as_cmap=True)
sns.heatmap(results/1000, annot=True, fmt=".2f",
linewidths=5, cmap=cmap,
cbar_kws={"shrink": .8}, square=True, ax=ax)
plt.title('Task execution time in seconds')
JAX is an exciting addition to the Python library landscape. Historically, if you wanted to speed up critical parts of your Python programs, the popular solutions were to rewrite them in C or Cython, or to rely on JIT compilers and GPU array libraries such as Numba and CuPy.
JAX is an extra option that feels more natural, as the code you write can run with or without accelerators with minimal changes. The main issue with the library is that NumPy behaviour is not always guaranteed. However, the developers are certainly striving towards an API as close to NumPy as possible, and since the library is still experimental we should hopefully see some improvement on this.
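A small example of such a behavioural difference: JAX arrays are immutable, so in-place updates go through the .at syntax instead of item assignment (a minimal sketch):
import numpy as np
import jax.numpy as jnp

x_np = np.zeros(3)
x_np[0] = 1.0                  # fine: NumPy arrays are mutable

x_jnp = jnp.zeros(3)
# x_jnp[0] = 1.0               # raises an error: JAX arrays are immutable
x_jnp = x_jnp.at[0].set(1.0)   # functional update returning a new array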