Gumbel Softmax Loss Function Guide + How to Implement it in PyTorch
Excerpt
A guide to Gumbel-Softmax in deep learning: discrete operations, a PyTorch implementation, and future prospects.
Training deep learning models has never been easier: you define the architecture and loss function, sit back, and monitor the training, at least in simple cases. Some architectures, however, come with inherent random components. This makes the forward pass stochastic, and your model is no longer deterministic.
In deterministic models, the output of the model is fully determined by parameter values and initial conditions.
Stochastic models have inherent randomness. The same set of parameter values and initial conditions will lead to an ensemble of different outputs.
This changes how sampling behaves: with a deterministic function, the same inputs always produce the same output, but once a stochastic component adds randomness, the sampling step itself becomes non-deterministic.
The backpropagation algorithm relies on chains of continuous, differentiable functions in each layer of the network, yet many architectures fundamentally rely on discrete operations. Since sampling from a discrete space is not differentiable the way sampling from a continuous one is, this is where the Gumbel-Softmax trick comes to the rescue. It lets sampling from a discrete space behave like a continuous operation, keeps the stochastic nature of the node intact, and keeps the backpropagation step viable.
Let's explore these operations with examples to gain a better understanding.
Discrete operations in Deep Learning
Discrete sampling appears in many areas of deep learning. In language models, for example, we sample sequences of word or character tokens, where each discrete token corresponds to a word or a character; in other words, we are sampling from a discrete space.
A sequence of word tokenizations demonstrating sampling from discrete space | Source
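To make the idea concrete, here is a tiny hypothetical sketch of such a discrete step, drawing the next token id from a predicted vocabulary distribution (the probabilities are made up for illustration):

```python
import torch

# Hypothetical next-token distribution over a 5-word vocabulary
vocab_probs = torch.tensor([0.05, 0.40, 0.25, 0.20, 0.10])

# Drawing a token id is a discrete sampling step; no gradient flows
# back through this choice during backpropagation
token_id = torch.multinomial(vocab_probs, num_samples=1)
print(token_id)  # e.g. tensor([1])
```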
Another popular example is the LSTM recurrent neural network architecture. It has internal gating mechanisms that are used to learn long-term dependencies. Although these gating units are continuous, the behavior of the LSTM cell as a whole, selectively keeping or discarding information, has a somewhat discrete character.
One more popular example of using discrete sampling in deep learning is the seq2seq DNC architecture. Seq2Seq-DNC uses Read / Write (discrete) operations on the external memory to store encoder-decoder states, in order to support long-range dependencies.
Operation structure of DNC memory area | Source
These read/write operations are sampled using another neural network architecture. In a way, this neural network is sampling from a discrete space.
Now let's take a look at the motivation and purpose behind Gumbel-Softmax.
Understanding Gumbel-Softmax
The problem Gumbel-Softmax addresses is working with discrete data drawn from a categorical distribution inside a network trained with backpropagation. Let's look at the inner mechanics behind it.
Gumbel Max trick
The Gumbel-Max trick is a technique that allows sampling from a categorical distribution during the forward pass of a neural network. Together with a smooth relaxation (described in the next section), it forms the basis of Gumbel-Softmax: the reparameterization handles the randomness, and the relaxation keeps the operation differentiable. Let's look at how this works.
Sampling from a categorical distribution by taking argmax of a combination class probabilities and Gumbel noise | Source
In this technique, we take the class probabilities and apply the logarithm to each, and to each of these logits we add Gumbel noise, which can be sampled by taking a double negative log of draws from a uniform distribution: g = -log(-log(u)). This step is similar to the reparameterization trick used in standard VAEs, where normally distributed noise is added to the mean.
After combining the deterministic and stochastic parts of the sampling process, we use the argmax function to find the class with the maximum value for each sample. The selected class is encoded as a one-hot vector for use by the rest of the neural network.
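As a minimal sketch of the Gumbel-Max trick (not part of the article's implementation further below; the class probabilities here are made up for illustration):

```python
import torch
import torch.nn.functional as F

# Hypothetical class probabilities for a 4-way categorical distribution
class_probs = torch.tensor([0.1, 0.6, 0.2, 0.1])

# Gumbel(0, 1) noise: g = -log(-log(u)), with u ~ Uniform(0, 1)
u = torch.rand_like(class_probs)
gumbel_noise = -torch.log(-torch.log(u))

# Add the noise to the log-probabilities and pick the largest entry
sample_index = torch.argmax(torch.log(class_probs) + gumbel_noise)

# One-hot encode the sampled class for the rest of the network
one_hot = F.one_hot(sample_index, num_classes=class_probs.numel())
```

Repeated draws like this follow the original class probabilities exactly.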
Now we have a way of sampling from a categorical distribution rather than a continuous one. However, we still can't backpropagate through argmax: its gradient is zero almost everywhere (and undefined at the boundaries), i.e., it's not differentiable.
The paper [3] proposed a technique that replaces argmax with softmax. Let's look into this.
Gumbel Softmax
Replacing argmax with softmax, because softmax is differentiable (as required by backpropagation) | Source: Author
In this approach, we still combine the log probabilities with Gumbel noise, but now we take the softmax over the samples instead of the argmax.
Lambda (λ) is the softmax temperature parameter, which controls how closely the Gumbel-Softmax distribution approximates the categorical distribution. If λ is very small, we get very close to a quantized (one-hot) categorical sample; conversely, the Gumbel-Softmax distribution becomes more uniform as λ increases.
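To see the effect of the temperature concretely, here is a small illustrative snippet (again with made-up numbers, independent of the implementation below):

```python
import torch
import torch.nn.functional as F

logits = torch.log(torch.tensor([0.1, 0.6, 0.2, 0.1]))
gumbel_noise = -torch.log(-torch.log(torch.rand_like(logits)))

for lam in [0.1, 1.0, 10.0]:
    # Softmax over the noisy logits, scaled by the temperature λ
    y = F.softmax((logits + gumbel_noise) / lam, dim=-1)
    print(lam, y)

# Small λ: y is nearly one-hot (close to a quantized categorical sample).
# Large λ: y approaches a uniform distribution over the classes.
```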
Implementation of Gumbel Softmax
In this section, we'll train a Variational Auto-Encoder on the MNIST dataset to reconstruct images. We'll apply Gumbel-Softmax when sampling from the encoder states. Let's code!
Note: We'll use PyTorch as our framework of choice for this implementation.
First, let's import the required dependencies.
import numpy as np

# Standard-library helpers kept from the original notebook
from typing import Union, Optional, List, Tuple, Text, BinaryIO
import io
import pathlib
import math
irange = range

import torch
import torch.nn.functional as F
from torch import nn, optim
from torchvision import datasets, transforms
from torchvision.utils import save_image

import neptune
from neptune.types import File

run = neptune.init_run(project='common/pytorch-integration',
                       api_token='ANONYMOUS')
It's always handy to define some hyper-parameters early on.
batch_size = 100
epochs = 10
temperature = 1.0   # initial Gumbel-Softmax temperature (annealed during training)
no_cuda = False     # set to True to force CPU training
seed = 2020
log_interval = 10   # how many batches to wait between training-status logs
hard = False        # if True, use straight-through (hard) one-hot samples
As mentioned earlier, we'll utilize MNIST for this implementation. Let's import it.
is_cuda = not no_cuda and torch.cuda.is_available()
torch.manual_seed(seed)
if is_cuda:
    torch.cuda.manual_seed(seed)

kwargs = {'num_workers': 1, 'pin_memory': True} if is_cuda else {}
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./data/MNIST', train=True, download=True,
                   transform=transforms.ToTensor()),
    batch_size=batch_size, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./data/MNIST', train=False, transform=transforms.ToTensor()),
    batch_size=batch_size, shuffle=True, **kwargs)
Now, we'll define the Gumbel-Softmax sampling helper functions.
def sample_gumbel(shape, eps=1e-20):
    """Sample Gumbel(0, 1) noise: g = -log(-log(u)), u ~ Uniform(0, 1)."""
    U = torch.rand(shape)
    if is_cuda:
        U = U.cuda()
    return -torch.log(-torch.log(U + eps) + eps)


def gumbel_softmax_sample(logits, temperature):
    """Draw a soft (relaxed) sample from the Gumbel-Softmax distribution."""
    y = logits + sample_gumbel(logits.size())
    return F.softmax(y / temperature, dim=-1)


def gumbel_softmax(logits, temperature, hard=False):
    """
    Straight-through Gumbel-Softmax.
    input: [*, n_class]
    return: flattened [batch, latent_dim * categorical_dim] one-hot (or relaxed) vector
    """
    y = gumbel_softmax_sample(logits, temperature)
    if not hard:
        return y.view(-1, latent_dim * categorical_dim)

    shape = y.size()
    _, ind = y.max(dim=-1)
    y_hard = torch.zeros_like(y).view(-1, shape[-1])
    y_hard.scatter_(1, ind.view(-1, 1), 1)
    y_hard = y_hard.view(*shape)
    # Straight-through estimator: the forward pass uses the hard one-hot y_hard,
    # the backward pass uses the gradients of the soft sample y
    y_hard = (y_hard - y).detach() + y
    return y_hard.view(-1, latent_dim * categorical_dim)
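If you want to try the helpers in isolation, a quick, hypothetical sanity check could look like this (it assigns placeholder values to latent_dim and categorical_dim, which are overwritten by the real hyper-parameters further below):

```python
# Placeholder dimensions just for this check (the real values come later)
latent_dim, categorical_dim = 3, 4
example_logits = torch.randn(2, latent_dim, categorical_dim)

soft = gumbel_softmax(example_logits, 1.0, hard=False)
hard_sample = gumbel_softmax(example_logits, 1.0, hard=True)

print(soft.shape)         # torch.Size([2, 12]) -- relaxed, fully differentiable samples
print(hard_sample.sum())  # tensor(6.) -- one 1 per categorical slot, i.e. exact one-hot vectors
```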
Next, let's define the VAE architecture and loss function.
def loss_function(recon_x, x, qy):
    # Reconstruction term: per-sample binary cross-entropy over the 784 pixels
    BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum') / x.shape[0]
    # KL divergence between q(y|x) and a uniform categorical prior
    log_ratio = torch.log(qy * categorical_dim + 1e-20)
    KLD = torch.sum(qy * log_ratio, dim=-1).mean()
    return BCE + KLD


class VAE_gumbel(nn.Module):
    def __init__(self, temp):
        super(VAE_gumbel, self).__init__()
        # Encoder: 784 -> 512 -> 256 -> latent_dim * categorical_dim logits
        self.fc1 = nn.Linear(784, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, latent_dim * categorical_dim)
        # Decoder: latent code -> 256 -> 512 -> 784 reconstructed pixels
        self.fc4 = nn.Linear(latent_dim * categorical_dim, 256)
        self.fc5 = nn.Linear(256, 512)
        self.fc6 = nn.Linear(512, 784)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def encode(self, x):
        h1 = self.relu(self.fc1(x))
        h2 = self.relu(self.fc2(h1))
        return self.relu(self.fc3(h2))

    def decode(self, z):
        h4 = self.relu(self.fc4(z))
        h5 = self.relu(self.fc5(h4))
        return self.sigmoid(self.fc6(h5))

    def forward(self, x, temp, hard):
        q = self.encode(x.view(-1, 784))
        q_y = q.view(q.size(0), latent_dim, categorical_dim)
        # Sample the discrete latent code with the Gumbel-Softmax trick
        z = gumbel_softmax(q_y, temp, hard)
        return self.decode(z), F.softmax(q_y, dim=-1).reshape(*q.size())
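A quick aside on the KL term in loss_function: because the prior over each categorical latent is uniform, KL(q‖p) simplifies to Σ q · log(q · K). A tiny, made-up check of that identity:

```python
import torch

# Hypothetical posterior over K = 4 categories for a single latent slot
qy = torch.tensor([0.7, 0.1, 0.1, 0.1])
K = qy.numel()

kl_closed_form = torch.sum(qy * torch.log(qy * K))
kl_manual = torch.sum(qy * (torch.log(qy) - torch.log(torch.full_like(qy, 1.0 / K))))
print(kl_closed_form, kl_manual)  # both ≈ 0.446
```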
Time for some more hyper-parameters.
latent_dim = 30          # number of categorical latent variables
categorical_dim = 10     # number of classes per latent variable
temp_min = 0.5           # lower bound for the annealed temperature
ANNEAL_RATE = 0.00003    # exponential annealing rate for the temperature

model = VAE_gumbel(temperature)
if is_cuda:
    model.cuda()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
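Before launching training, it can be reassuring to push a dummy batch through the model to confirm that the shapes line up (an optional, illustrative check that is not in the original notebook):

```python
# Fake batch of four MNIST-sized images
dummy = torch.rand(4, 1, 28, 28)
if is_cuda:
    dummy = dummy.cuda()

recon, qy = model(dummy, temperature, hard)
print(recon.shape, qy.shape)  # expected: torch.Size([4, 784]) torch.Size([4, 300])
```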
We'll define separate training and testing functions.
In the testing function we reconstruct images from held-out data, which lets us check how well the sampling and the model generalize to unseen samples.
def train(epoch):
    model.train()
    train_loss = 0
    temp = temperature
    for batch_idx, (data, _) in enumerate(train_loader):
        if is_cuda:
            data = data.cuda()
        optimizer.zero_grad()
        recon_batch, qy = model(data, temp, hard)
        loss = loss_function(recon_batch, data, qy)
        loss.backward()
        train_loss += loss.item() * len(data)
        optimizer.step()
        # Anneal the temperature roughly every 100 batches, but never below temp_min
        if batch_idx % 100 == 1:
            temp = np.maximum(temp * np.exp(-ANNEAL_RATE * batch_idx), temp_min)
        # Log a grid of reconstructed images for the first batch of every epoch
        if batch_idx == 0:
            reconstructed_image = recon_batch.view(batch_size, 1, 28, 28)
            grid_array = get_grid(reconstructed_image)
            run["train_reconstructed_images/{}".format('training_reconstruction_' + str(epoch))].upload(File.as_image(grid_array))
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                loss.item()))
    print('====> Epoch: {} Average loss: {:.4f}'.format(
        epoch, train_loss / len(train_loader.dataset)))
    run['metrics/avg_train_loss'].log(train_loss / len(train_loader.dataset))
def test(epoch):
    model.eval()
    test_loss = 0
    temp = temperature
    for i, (data, _) in enumerate(test_loader):
        if is_cuda:
            data = data.cuda()
        recon_batch, qy = model(data, temp, hard)
        test_loss += loss_function(recon_batch, data, qy).item() * len(data)
        if i % 100 == 1:
            temp = np.maximum(temp * np.exp(-ANNEAL_RATE * i), temp_min)
        # For the first batch, log original images next to their reconstructions
        if i == 0:
            n = min(data.size(0), 8)
            comparison = torch.cat([data[:n],
                                    recon_batch.view(batch_size, 1, 28, 28)[:n]])
            grid_array = get_grid(comparison)
            run["test_reconstructed_images/{}".format('test_reconstruction_' + str(epoch))].upload(File.as_image(grid_array))
    test_loss /= len(test_loader.dataset)
    print('====> Test set loss: {:.4f}'.format(test_loss))
    run['metrics/avg_test_loss'].log(test_loss)
Note: Please find the utility functions used in the code excerpt above in the notebook here.
Finally, we'll define the training loop that runs the individual functions together.
for epoch in range(1, epochs + 1):
    train(epoch)
    test(epoch)
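Once the loop has finished, it is good practice to close the Neptune run explicitly (this line is not in the original excerpt):

```python
# Stop the Neptune run so all logged metadata is synchronized
run.stop()
```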
At the end of successful execution, you will get a reconstructed image of the MNIST samples like this:
Comparing the reconstructions with the originals gives a sense of how well sampling with Gumbel-Softmax worked. We can see the training and test loss converging in the plot below:
You can access the complete experiment bundled with reconstructed images here and the code above here.
You've reached the end!
The Gumbel-Softmax trick can prove very useful in discrete sampling tasks that previously had to be handled in other, less convenient ways. For example, NLP tasks are almost inherently discrete, such as sampling words, characters, or phonemes.
Future prospects
The Gumbel-Softmax paper also mentioned its usefulness in Variational Autoencoders, but it's certainly not limited to that.
You can apply the same technique to binary autoencoders and other complex neural networks, such as Generative Adversarial Networks (GANs). The possibilities seem almost limitless.
That's it for now, stay tuned for more! Adios!
References
- https://www.youtube.com/watch?v=JFgXEbgcT7g
- https://sassafras13.github.io/GumbelSoftmax/
- https://arxiv.org/pdf/1611.01144.pdf
- https://towardsdatascience.com/what-is-gumbel-softmax-7f6d9cdcb90e
- https://blog.evjang.com/2016/11/tutorial-categorical-variational.html
- https://arxiv.org/abs/1611.00712