I decided to do this for two reasons. The first is that, for years, I had to put up with my Ph.D. advisor coming into the lab while I was happily coding my PyTorch model, slowly sneaking up behind me, staring at my screen and saying, with a disappointed look, “you should definitely do this in JAX”. The second reason is this nice blog post from Neel Gupta.
However, every time I tried to use JAX, I ended up using Flax instead, which offers a kind of object-oriented interface (similar to torch). While Flax is great, it introduces additional layers of abstraction that make it similar to PyTorch, so I ended up wondering: “why am I doing this?”. There are other great frameworks as well, with different functionalities, like Equinox (maybe closer to JAX’s original nature), but they always add “another layer”.
This time, I wanted to get a taste of bare JAX and avoid external libraries or abstractions. In this implementation, I’ve built a basic Vision Transformer entirely from scratch. Although it may not be the most efficient code, my focus is to explore JAX directly and train a small model while leveraging JAX’s core features, like `vmap` and `jit`, without any external frameworks.
I will cover the following topics:
- Initialization of the weights (in pure JAX it can take a while)
- Coding the ViT logic and parallelization (with `jax.vmap`)
- Training with just-in-time compilation (with `jax.jit`)
✋ If you are not interested in model initialization, you can just skip to the core part where we implement the model and train it.
Vision Transformer
In the following, I assume you are already familiar with the Vision Transformer architecture. If you are not, you can take a look here. In short, ViTs split images into patches, and treat patches as tokens (like words in NLP models), processing them using transformer layers with bidirectional (non masked) attention. In this post, we’ll build a small ViT that can train on the Imagenette dataset, and you can even run it on your local machine.
Speaking of GPUs, JAX offers seamless handling of hardware acceleration. It automatically detects and utilizes available GPUs/TPUs without requiring explicit code changes.
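A quick way to check what JAX has picked up is to list the visible devices. This is a minimal sketch using `jax.devices()` and `jax.default_backend()`; the output depends on your machine, so none is shown here.

```python
import jax
import jax.numpy as jnp

# Devices JAX can see (CPU, GPU or TPU) and the backend used by default.
devices = jax.devices()
backend = jax.default_backend()
print(devices, backend)

# Arrays are created on the default device; no explicit `.to(device)` call is needed.
x = jnp.ones((2, 2))
print(x.sum())
```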
In this notebook, we are going to use a small ViT, with the following hyperparameters:
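For reference, a configuration for a ViT of this size could look like the sketch below. All names and values are illustrative assumptions and not necessarily the ones used in the notebook; the only fixed fact is that Imagenette has 10 classes.

```python
# Illustrative hyperparameters (assumed values, not the notebook's own).
config = {
    "image_size": 64,    # images are resized to image_size x image_size
    "patch_size": 8,     # each patch is patch_size x patch_size pixels
    "channels": 3,       # RGB
    "hidden_dim": 192,   # token embedding size
    "num_heads": 4,      # attention heads per block
    "num_layers": 4,     # number of transformer blocks
    "mlp_dim": 384,      # hidden size of the feed-forward MLP
    "num_classes": 10,   # Imagenette has 10 classes
}

# Derived quantities: number of patches per image and flattened patch size.
num_patches = (config["image_size"] // config["patch_size"]) ** 2
patch_dim = config["patch_size"] ** 2 * config["channels"]
```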
Initializing the model
JAX is a fully functional framework, which means that model parameters are treated as a distinct set of numbers, existing “outside” the model itself. This gives you a nice, low-level feel for how the model works. Instead of encapsulating parameters within an object (like in torch), you’re directly manipulating a concrete set of weights along with a function that processes them.
To initialize these weights at random, we need some random primitives (just like in torch). In JAX, every call to a random primitive requires a random key, which ensures that the randomness is both explicit and controllable. This means that instead of calling a random function that relies on a hidden global state, you have to explicitly allocate a key first and then use it to generate a random number, like this:
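A minimal sketch of this pattern, using `jax.random.PRNGKey`, `jax.random.normal` and `jax.random.split` (the seed and shapes are arbitrary choices for illustration):

```python
import jax

# Allocate a key explicitly from a seed.
key = jax.random.PRNGKey(0)

# Every random call takes a key; the same key always gives the same numbers.
x = jax.random.normal(key, (3,))
x_again = jax.random.normal(key, (3,))

# To get fresh randomness, split the key and use the new subkey.
key, subkey = jax.random.split(key)
y = jax.random.normal(subkey, (3,))
```

Reusing a key reproduces the same values, which is why keys are split before every new random call.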
This is great for ML practitioners, and you know what I’m talking about if you ever had to use torch random seeding and ended up with reproducibility issues. The main reason JAX tracks random keys explicitly instead of using a global random state is that a global state would compromise the execution of parallel code, which is one of the main perks of JAX. You can read more about randomness in JAX here.
Let’s see what parameters we need for our ViT. Here is the list:
- a `CLS` (classification) token
- a projection to transform the patches into tokens
- a positional encoding
- N transformer blocks made of multi-head attention and feed-forward MLP
- a final head for classification
As mentioned before, these parameters are just numbers that we can store in a dictionary like this:
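As a sketch, the dictionary could be laid out as follows. The key names are assumptions chosen for illustration; the values are placeholders to be filled in by the initialization steps.

```python
# Skeleton of the parameter dictionary (hypothetical key names).
vit_parameters = {
    "cls_token": None,    # learnable classification token
    "patch_embed": None,  # projection from flattened patches to tokens
    "pos_embed": None,    # positional encoding, one vector per token
    "layers": [],         # one entry per transformer block
    "head": None,         # final classification head
}
```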
We now need to initialize each set of weights separately. Again, we could rely on a library such as Flax for this, but we want to go through the process manually to better understand what’s happening under the hood.
We initialize the class token with all zeros:
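In its simplest form this is a single call to `jnp.zeros`. The shape `(1, hidden_dim)` and the value of `hidden_dim` are assumptions for this sketch.

```python
import jax.numpy as jnp

hidden_dim = 192  # assumed embedding size

# One token of size hidden_dim, all zeros; no random key is needed here.
cls_token = jnp.zeros((1, hidden_dim))
```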
For patch embedding, positional encoding, final head, and all transformer blocks we use random values (check the colab for complete code). Each transformer block is made up of `attention`, `mlp` and layer normalization. We define a function to initialize each of these components. I’ll show just the MLP initialization here for brevity. I’ll do it using Xavier initialization, but this is not crucial and you can just use a random normal.
For the MLP, we need weights and biases for 2 layers.
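A possible version of this initialization is sketched below, assuming Xavier (Glorot) uniform weights and zero biases. The helper names `xavier_uniform` and `init_mlp`, the dictionary keys, and the sizes are hypothetical.

```python
import jax
import jax.numpy as jnp

def xavier_uniform(key, fan_in, fan_out):
    # Xavier/Glorot uniform: sample from U(-limit, limit), limit = sqrt(6 / (fan_in + fan_out)).
    limit = jnp.sqrt(6.0 / (fan_in + fan_out))
    return jax.random.uniform(key, (fan_in, fan_out), minval=-limit, maxval=limit)

def init_mlp(key, hidden_dim, mlp_dim):
    # One key per weight matrix, so the two layers get independent values.
    k1, k2 = jax.random.split(key)
    return {
        "w1": xavier_uniform(k1, hidden_dim, mlp_dim),  # up projection
        "b1": jnp.zeros((mlp_dim,)),
        "w2": xavier_uniform(k2, mlp_dim, hidden_dim),  # down projection
        "b2": jnp.zeros((hidden_dim,)),
    }

mlp_params = init_mlp(jax.random.PRNGKey(0), hidden_dim=192, mlp_dim=384)
```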
We are now ready to initialize all the weights in each transformer layer! Let’s create a set of parameters for each layer and store them in our dictionary.
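One way to do this is to split a key into one subkey per layer and call a per-layer initializer in a loop. The sketch below assumes a hypothetical `init_layer` helper, hypothetical dictionary keys, and illustrative sizes.

```python
import jax
import jax.numpy as jnp

def xavier_uniform(key, fan_in, fan_out):
    limit = jnp.sqrt(6.0 / (fan_in + fan_out))
    return jax.random.uniform(key, (fan_in, fan_out), minval=-limit, maxval=limit)

def init_layer(key, d, mlp_dim):
    # Four weight matrices per block, so four independent keys.
    k_qkv, k_out, k_up, k_down = jax.random.split(key, 4)
    return {
        "ln1": {"scale": jnp.ones((d,)), "bias": jnp.zeros((d,))},
        "attn": {"w_qkv": xavier_uniform(k_qkv, d, 3 * d), "b_qkv": jnp.zeros((3 * d,)),
                 "w_out": xavier_uniform(k_out, d, d), "b_out": jnp.zeros((d,))},
        "ln2": {"scale": jnp.ones((d,)), "bias": jnp.zeros((d,))},
        "mlp": {"w1": xavier_uniform(k_up, d, mlp_dim), "b1": jnp.zeros((mlp_dim,)),
                "w2": xavier_uniform(k_down, mlp_dim, d), "b2": jnp.zeros((d,))},
    }

num_layers, d, mlp_dim = 4, 192, 384  # assumed sizes
vit_parameters = {"layers": []}

# One subkey per layer, so no two layers share the same random values.
for layer_key in jax.random.split(jax.random.PRNGKey(0), num_layers):
    vit_parameters["layers"].append(init_layer(layer_key, d, mlp_dim))
```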
Finally, we can now write the code for the transformer encoder!
The Model is Just a Function
One thing we quickly notice about JAX is that everything is a function, including models. This is very different from torch, where we usually look at the model as a composition of objects (`nn.Module`s). So we’ll write the forward pass as just a function. The ViT function will take the ViT parameters and an image as input, that is:
As we can see, the parameters are “outside” of the model. Before writing the actual code for the ViT, we come to another special feature of JAX: parallelization. Thanks to JAX’s native `vmap` function, we’ll just pretend there is no batch dimension and then use `vmap` to automagically handle batches. This is a great improvement, as we don’t have to reason about one additional dimension and there is no need for stuff like `batch, sequence, dim = input.shape` (unlike torch). So from now on, we’ll just ignore the batch dimension.
Don’t forget that each transformer block is nothing but a function over the model parameters and an input. For the MLP, we just perform an up and down projection with a ReLU activation function in the middle. Notice that we get the parameters from the dictionary we created earlier.
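A minimal sketch of such an MLP, written for a single sequence with no batch dimension. The dictionary keys `w1`, `b1`, `w2`, `b2` and the small sizes are assumptions for illustration.

```python
import jax
import jax.numpy as jnp

def mlp(params, x):
    # x has shape (seq_len, hidden_dim); there is no batch dimension.
    h = jax.nn.relu(x @ params["w1"] + params["b1"])  # up projection + ReLU
    return h @ params["w2"] + params["b2"]            # down projection

# Tiny illustrative parameters: hidden_dim=8, mlp_dim=16.
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
params = {
    "w1": 0.02 * jax.random.normal(k1, (8, 16)), "b1": jnp.zeros((16,)),
    "w2": 0.02 * jax.random.normal(k2, (16, 8)), "b2": jnp.zeros((8,)),
}
x = jax.random.normal(k3, (5, 8))  # 5 tokens of size 8
out = mlp(params, x)
```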
Now for self-attention. The only catch here is to project into multiple heads and then concatenate them back:
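The sketch below shows one way to write multi-head self-attention for a single sequence, assuming a fused query/key/value projection. The parameter names and sizes are hypothetical and may differ from the notebook.

```python
import jax
import jax.numpy as jnp

def attention(params, x, num_heads):
    seq_len, d = x.shape
    head_dim = d // num_heads
    qkv = x @ params["w_qkv"] + params["b_qkv"]   # (seq_len, 3 * d)
    q, k, v = jnp.split(qkv, 3, axis=-1)          # each (seq_len, d)

    # Split the feature axis into heads: (seq_len, d) -> (heads, seq_len, head_dim).
    def to_heads(t):
        return t.reshape(seq_len, num_heads, head_dim).transpose(1, 0, 2)

    q, k, v = to_heads(q), to_heads(k), to_heads(v)
    scores = q @ k.transpose(0, 2, 1) / jnp.sqrt(head_dim)  # (heads, seq_len, seq_len)
    weights = jax.nn.softmax(scores, axis=-1)               # no mask: bidirectional
    out = weights @ v                                       # (heads, seq_len, head_dim)

    # Concatenate the heads back: (heads, seq_len, head_dim) -> (seq_len, d).
    out = out.transpose(1, 0, 2).reshape(seq_len, d)
    return out @ params["w_out"] + params["b_out"]

d, seq_len, num_heads = 16, 5, 4  # assumed sizes
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
params = {
    "w_qkv": 0.02 * jax.random.normal(k1, (d, 3 * d)), "b_qkv": jnp.zeros((3 * d,)),
    "w_out": 0.02 * jax.random.normal(k2, (d, d)), "b_out": jnp.zeros((d,)),
}
x = jax.random.normal(k3, (seq_len, d))
out = attention(params, x, num_heads)
```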
Finally, we can assemble attention and mlps into a transformer block.
Before feeding the first image to the model, we need one more additional step to transform an input image into a sequence of patches. To do that, we use einops, which offers a highly expressive interface to reshape tensors. Another way would be applying convolutions but here we are just using bare JAX code so we get a sequence of tokens from an image like this:
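To keep this sketch free of extra dependencies, it uses a plain reshape and transpose instead of einops; the equivalent einops pattern is given in a comment. The `(H, W, C)` image layout and the `patchify` name are assumptions.

```python
import jax.numpy as jnp

def patchify(image, patch_size):
    # image: (H, W, C) -> (num_patches, patch_size * patch_size * C)
    # Equivalent einops pattern: "(h p1) (w p2) c -> (h w) (p1 p2 c)"
    H, W, C = image.shape
    h, w = H // patch_size, W // patch_size
    x = image.reshape(h, patch_size, w, patch_size, C)
    x = x.transpose(0, 2, 1, 3, 4)  # (h, w, p1, p2, C)
    return x.reshape(h * w, patch_size * patch_size * C)

image = jnp.arange(64 * 64 * 3, dtype=jnp.float32).reshape(64, 64, 3)
patches = patchify(image, patch_size=8)  # 64 patches of 8 * 8 * 3 = 192 values
```

Each row of `patches` is one flattened patch, ordered row by row across the image.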
With this, we are ready to go! The final transformer then works by:
- reshaping the image into patches
- projecting patches into tokens
- adding a class token and positional embeddings
- looping through a stack of transformer blocks
- applying the final classification head
Let’s implement these steps:
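The following is a compact, self-contained sketch of these five steps. It assumes pre-norm blocks, random-normal initialization, and small illustrative sizes; every name and value here is an assumption and the notebook code may differ.

```python
import jax
import jax.numpy as jnp

P, D, HEADS, LAYERS, CLASSES, IMG = 8, 32, 4, 2, 10, 32  # illustrative sizes

def init_params(key):
    keys = iter(jax.random.split(key, 3 + 4 * LAYERS))
    rnd = lambda *shape: 0.02 * jax.random.normal(next(keys), shape)
    layers = [{
        "ln1": {"scale": jnp.ones(D), "bias": jnp.zeros(D)},
        "attn": {"w_qkv": rnd(D, 3 * D), "b_qkv": jnp.zeros(3 * D),
                 "w_out": rnd(D, D), "b_out": jnp.zeros(D)},
        "ln2": {"scale": jnp.ones(D), "bias": jnp.zeros(D)},
        "mlp": {"w1": rnd(D, 4 * D), "b1": jnp.zeros(4 * D),
                "w2": rnd(4 * D, D), "b2": jnp.zeros(D)},
    } for _ in range(LAYERS)]
    return {
        "cls_token": jnp.zeros((1, D)),
        "patch_embed": {"w": rnd(P * P * 3, D), "b": jnp.zeros(D)},
        "pos_embed": rnd((IMG // P) ** 2 + 1, D),
        "layers": layers,
        "head": {"w": rnd(D, CLASSES), "b": jnp.zeros(CLASSES)},
    }

def layer_norm(p, x, eps=1e-6):
    mean, var = x.mean(-1, keepdims=True), x.var(-1, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps) * p["scale"] + p["bias"]

def mlp(p, x):
    return jax.nn.relu(x @ p["w1"] + p["b1"]) @ p["w2"] + p["b2"]

def attention(p, x):
    n, d = x.shape
    q, k, v = jnp.split(x @ p["w_qkv"] + p["b_qkv"], 3, axis=-1)
    heads = lambda t: t.reshape(n, HEADS, d // HEADS).transpose(1, 0, 2)
    q, k, v = heads(q), heads(k), heads(v)
    w = jax.nn.softmax(q @ k.transpose(0, 2, 1) / jnp.sqrt(d // HEADS), axis=-1)
    out = (w @ v).transpose(1, 0, 2).reshape(n, d)  # concatenate heads
    return out @ p["w_out"] + p["b_out"]

def block(p, x):
    # Pre-norm residual block (an assumption; post-norm is also common).
    x = x + attention(p["attn"], layer_norm(p["ln1"], x))
    return x + mlp(p["mlp"], layer_norm(p["ln2"], x))

def patchify(image):
    h = w = IMG // P
    x = image.reshape(h, P, w, P, 3).transpose(0, 2, 1, 3, 4)
    return x.reshape(h * w, P * P * 3)

def vit(params, image):
    x = patchify(image)                                              # 1. image -> patches
    x = x @ params["patch_embed"]["w"] + params["patch_embed"]["b"]  # 2. patches -> tokens
    x = jnp.concatenate([params["cls_token"], x], axis=0)            # 3. prepend class token
    x = x + params["pos_embed"]                                      #    add positional embeddings
    for layer in params["layers"]:                                   # 4. transformer blocks
        x = block(layer, x)
    return x[0] @ params["head"]["w"] + params["head"]["b"]          # 5. head on the class token

params = init_params(jax.random.PRNGKey(0))
image = jax.random.normal(jax.random.PRNGKey(1), (IMG, IMG, 3))  # single image, no batch axis
logits = vit(params, image)                                      # shape (CLASSES,)
```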
Let’s test it on a random input:
As you may have noticed, the random input is just an image without a batch dimension. Let’s see how we can add a batch dimension without modifying the code.
Vectorized Mapping with `vmap`
As anticipated, before jumping into training, we’ll look at one of the coolest features JAX offers: `vmap`. This allows you to vectorize your functions, meaning you can apply them over batches of data without writing explicit loops. In a way, it’s like automatic batching. You write a function that works on a single example, and `vmap` will apply it to all examples in a batch in one go.
For example, if you have a function that processes a single image, you can turn it into a function that processes an entire batch of images with just one line:
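As a small illustration, here is a hypothetical per-image function (`brightness`, invented for this sketch) turned into a batched one with `jax.vmap`:

```python
import jax
import jax.numpy as jnp

def brightness(image):
    # Works on a single image of shape (H, W, C) and returns a scalar.
    return image.mean()

# One line: map the function over the leading (batch) axis.
batched_brightness = jax.vmap(brightness)

batch = jnp.ones((16, 32, 32, 3))
out = batched_brightness(batch)  # one value per image, shape (16,)
```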
This can come in handy when applying the model over a batch of data. This means that we can run one pass of our transformer over a batch of images very easily. Let’s try:
Let’s apply `vmap`. We map over the first axis of the images (the batch axis, `0`) and pass `None` for the parameters, so that the same parameters are shared by every input in the batch.
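A minimal sketch of this use of `in_axes`, with a tiny linear model standing in for the ViT. The `(params, image)` argument order is an assumption; with the opposite order the tuple would be `(0, None)`.

```python
import jax
import jax.numpy as jnp

def model(params, image):
    # Stand-in for the ViT forward pass on a single image (no batch axis).
    return image.reshape(-1) @ params["w"] + params["b"]

params = {"w": 0.01 * jnp.ones((32 * 32 * 3, 10)), "b": jnp.zeros((10,))}

# None: do not map over params (shared by all examples).
# 0: map over axis 0 of the images (the batch axis).
batched_model = jax.vmap(model, in_axes=(None, 0))

images = jnp.ones((16, 32, 32, 3))
logits = batched_model(params, images)  # shape (16, 10)
```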
Actually, `vmap` can do way more than this. I recommend this blog for an overview.
Loss Function
Next up is the loss function. We’ll use the Cross-Entropy Loss, which is a standard choice for classification tasks.
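A minimal sketch of cross-entropy for a single example, batched with `vmap`. The function name and the use of integer class labels are assumptions.

```python
import jax
import jax.numpy as jnp

def cross_entropy(logits, label):
    # logits: (num_classes,), label: integer class index.
    log_probs = jax.nn.log_softmax(logits)
    return -log_probs[label]

# Batch it with vmap and average over the batch.
logits = jnp.zeros((4, 10))          # uniform predictions over 10 classes
labels = jnp.array([0, 3, 5, 9])
loss = jax.vmap(cross_entropy)(logits, labels).mean()
```

With uniform predictions over 10 classes, the loss equals log(10), which is a handy sanity check at the start of training.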
Dataset Loading (Stealing From PyTorch)
For dataset loading, I’m going to steal some code from PyTorch. PyTorch’s data utilities work really well, and since this isn’t a post about data loading, we’ll skip the hassle of reinventing the wheel here.
Let’s code a simple evaluation function that loops over the test data and computes accuracy
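A possible shape for such a function is sketched below. It assumes a batched prediction function and any iterable of `(images, labels)` batches; with a PyTorch `DataLoader`, the tensors would need to be converted to NumPy arrays first. The toy model and data are hypothetical.

```python
import jax.numpy as jnp

def evaluate(predict_fn, params, batches):
    # predict_fn(params, images) must return logits of shape (batch, num_classes).
    correct, total = 0, 0
    for images, labels in batches:
        logits = predict_fn(params, jnp.asarray(images))
        preds = jnp.argmax(logits, axis=-1)
        correct += int((preds == jnp.asarray(labels)).sum())
        total += len(labels)
    return correct / total

# Toy check: a linear "model" with identity weights predicts the largest feature.
toy_predict = lambda params, x: x @ params
toy_params = jnp.eye(3)
toy_batches = [
    ([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]], [0, 1, 2]),
    ([[0.0, 1.0, 0.0], [1.0, 0.0, 0.0]], [1, 2]),
]
accuracy = evaluate(toy_predict, toy_params, toy_batches)  # 4 of 5 correct
```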
Training and Just in Time compilation
Before we dive into training, we meet another cool feature of JAX: `jit`, that is, just-in-time compilation. One of JAX’s biggest selling points is its ability to automatically compile and optimize your code using just-in-time (JIT) compilation. With JAX, you can wrap your functions in `jax.jit()` to make them faster by compiling them into optimized code with XLA. It’s a one-liner and can massively speed up your training loop. In essence, `jit` lets you write Python code, and JAX will optimize it behind the scenes. It’s not even hard to use, so there’s no reason not to take advantage of it!
Here’s how you can JIT-compile your training step:
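The same one-liner applies to any pure function, including a training step. A minimal sketch with a hypothetical stand-in function:

```python
import jax
import jax.numpy as jnp

def step(x, w):
    # Stand-in for an expensive pure function (hypothetical example).
    return jnp.tanh(x @ w).sum()

# Compiled on the first call for these input shapes, then reused.
fast_step = jax.jit(step)

x = jnp.ones((8, 4))
w = jnp.ones((4, 2))
slow = step(x, w)
fast = fast_step(x, w)  # same result as the uncompiled function
```

The first call pays the compilation cost; later calls with the same shapes and dtypes reuse the compiled code.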
Parameter updates
This is where JAX differs most from PyTorch. In PyTorch, you call `.backward()` on your loss, and it computes the gradients for you and stores them in the `.grad` attribute of your model parameters. In JAX, you need to explicitly compute gradients and update parameters yourself, which gives you a more hands-on experience with the inner workings of optimization.
To perform gradient descent, we’ll compute the gradient of the loss with respect to the parameters. In JAX, you can do this using the `jax.value_and_grad` function:
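A minimal sketch on a toy linear regression loss (the loss function and data are hypothetical). `jax.value_and_grad` differentiates with respect to the first argument by default and returns both the loss and the gradients.

```python
import jax
import jax.numpy as jnp

def loss_fn(params, x, y):
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

params = {"w": jnp.zeros((3,)), "b": jnp.array(0.0)}
x = jnp.ones((4, 3))
y = jnp.ones((4,))

# Returns the loss value and a pytree of gradients shaped like `params`.
loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
```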
What will the gradients look like? The gradients are just going to be a dictionary (pytree) with the same keys as the model parameters, but instead of holding the parameters, they will hold the gradients.
We now have a dictionary of gradients that mirrors the structure of our parameters. To update the parameters, we’ll perform simple gradient descent. The update rule is:
$$\theta \leftarrow \theta - \eta \, \nabla_\theta \mathcal{L}(\theta)$$
Where:
- $\theta$ is the parameter we’re updating, in our case the dictionary.
- $\eta$ is the learning rate.
- $\nabla_\theta \mathcal{L}(\theta)$ is the gradient of the loss with respect to the parameter, in our case the gradients dictionary.
JAX has some great libraries for optimization, like `optax`, but for simplicity, we’ll just manually update the parameters using vanilla SGD. Notice that to do this we’d have to go through the dictionary and update all values that have the same key. Fortunately, JAX has a function that does that for us: `jax.tree.map`.
We just have to tell it the gradient descent rule:
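A minimal sketch of the update on toy values. It uses `jax.tree_util.tree_map`, the long-standing name of the same function, so that it also runs on older JAX versions; the learning rate and values are arbitrary.

```python
import jax
import jax.numpy as jnp

lr = 0.1  # learning rate (illustrative value)
params = {"w": jnp.array([1.0, 2.0]), "b": jnp.array(0.5)}
grads = {"w": jnp.array([10.0, -10.0]), "b": jnp.array(1.0)}

# Apply p <- p - lr * g to every matching pair of leaves in the two pytrees.
# In recent JAX versions, jax.tree.map is an alias for this function.
new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
```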
Putting everything together, the training step will look like this:
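A self-contained sketch of such a training step, with a tiny linear classifier standing in for the ViT. The model, data, learning rate, and number of steps are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

def model(params, x):
    # Stand-in for the ViT forward pass on a single example.
    return x @ params["w"] + params["b"]

def loss_fn(params, xs, ys):
    logits = jax.vmap(model, in_axes=(None, 0))(params, xs)  # batch the model
    log_probs = jax.nn.log_softmax(logits)
    return -jnp.mean(log_probs[jnp.arange(ys.shape[0]), ys])  # cross-entropy

@jax.jit
def train_step(params, xs, ys, lr):
    loss, grads = jax.value_and_grad(loss_fn)(params, xs, ys)
    # Vanilla SGD applied leaf by leaf.
    params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return params, loss

# Toy data: three one-hot inputs, one per class.
xs = jnp.eye(3)
ys = jnp.array([0, 1, 2])
params = {"w": jnp.zeros((3, 3)), "b": jnp.zeros((3,))}

losses = []
for _ in range(20):
    params, loss = train_step(params, xs, ys, 0.5)
    losses.append(float(loss))
```

Because `train_step` returns the new parameters instead of modifying them in place, the loop simply rebinds `params` at every iteration.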
Finally, let’s train the model. We don’t expect any special results because we are training without any optimization tricks and with a super small model. Also, I’ll only train for 30 epochs here, but you can let it go on for longer.
Hope you enjoyed this, please reach me at https://alessiodevoto.github.io/ if you have any questions or find inconsistencies!
Thanks to Rahil for spotting a bug in the attention shapes!