Title: Categorical Reparameterization with Gumbel-Softmax
Authors: Eric Jang, Shixiang Gu, Ben Poole
Published: 3rd November 2016 (Thursday) @ 19:48:08
Link: http://arxiv.org/abs/1611.01144v5
Abstract
Categorical variables are a natural choice for representing discrete structure in the world. However, stochastic neural networks rarely use categorical latent variables due to the inability to backpropagate through samples. In this work, we present an efficient gradient estimator that replaces the non-differentiable sample from a categorical distribution with a differentiable sample from a novel Gumbel-Softmax distribution. This distribution has the essential property that it can be smoothly annealed into a categorical distribution. We show that our Gumbel-Softmax estimator outperforms state-of-the-art gradient estimators on structured output prediction and unsupervised generative modeling tasks with categorical latent variables, and enables large speedups on semi-supervised classification.
Phind explains the Gumbel-Softmax
The Gumbel-Softmax is a continuous distribution that can be smoothly annealed into a categorical distribution. It is a technique for handling discrete data generated from a categorical distribution, making it useful in tasks like Natural Language Processing (NLP), where the sampling of words, characters, or phonemes is discrete (Gumbel Softmax Loss Function Guide + How to Implement it in PyTorch).
The Gumbel-Softmax is based on the Gumbel-Max trick, which operates primarily in log-space, thereby avoiding potentially nasty numerical over/under-flow errors and unexpected or incorrect sampling behavior. It also bypasses the need for marginalization (i.e., exp-sum), which can be expensive for a large number of categories (The Gumbel-Max Trick Explained: Softmax's slicker sibling, by Leonard Tang, The Startup, Medium).
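For reference, here is a minimal sketch of the Gumbel-Max trick in PyTorch; the function name is illustrative rather than taken from the paper or the cited articles:

import torch

def gumbel_max_sample(logits):
    # Draw standard Gumbel(0, 1) noise: -log(-log(U)), with U ~ Uniform(0, 1);
    # the 1e-20 terms guard against taking the log of zero
    gumbels = -torch.log(-torch.log(torch.rand_like(logits) + 1e-20) + 1e-20)
    # Adding the noise to the log-probabilities (logits) and taking the argmax
    # yields an exact sample from the corresponding categorical distribution
    return torch.argmax(logits + gumbels, dim=-1)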
The Gumbel-Softmax distribution is smooth for λ > 0, and therefore has a well-defined gradient with respect to the class probabilities π. By replacing categorical samples with Gumbel-Softmax samples, we can use backpropagation to compute gradients (Gumbel-Softmax, Niansong Zhang).
In the Gumbel-Softmax approach, we still combine the log probabilities with Gumbel noise, but we take the softmax over the perturbed values instead of the argmax. The parameter λ (lambda) is the softmax temperature, which controls how closely the Gumbel-Softmax distribution approximates the categorical distribution: if λ is very small, we get very close to a quantized (one-hot) categorical sample, and conversely, the Gumbel-Softmax distribution becomes more uniform as λ increases (Gumbel Softmax Loss Function Guide + How to Implement it in PyTorch).
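Written out, with class probabilities π_1, ..., π_k, i.i.d. Gumbel(0, 1) noise g_1, ..., g_k, and temperature λ (the paper writes τ), each component of a Gumbel-Softmax sample y is

$$
y_i = \frac{\exp\big((\log \pi_i + g_i)/\lambda\big)}{\sum_{j=1}^{k} \exp\big((\log \pi_j + g_j)/\lambda\big)}, \qquad i = 1, \dots, k.
$$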
Here is a simple implementation of Gumbel-Softmax in PyTorch:
import torch

def gumbel_softmax(logits, temperature):
    # Sample standard Gumbel(0, 1) noise: -log(-log(U)), with U ~ Uniform(0, 1);
    # the 1e-20 terms guard against taking the log of zero
    gumbels = -torch.log(-torch.log(torch.rand_like(logits) + 1e-20) + 1e-20)
    # Perturb the logits with the noise and apply a temperature-scaled softmax
    return torch.nn.functional.softmax((logits + gumbels) / temperature, dim=-1)
In this function, logits are the raw, unnormalized scores output by your model for each class, and temperature is the softmax temperature parameter λ. The function perturbs the logits with Gumbel noise, applies the temperature-scaled softmax, and returns the resulting relaxed (soft) one-hot sample (Gumbel Softmax Loss Function Guide + How to Implement it in PyTorch).
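As a quick sanity check (a hypothetical usage snippet, not from the cited guide), sampling the same logits at a few temperatures illustrates the annealing behavior described above: small λ gives nearly one-hot samples, large λ gives nearly uniform ones.

import torch

logits = torch.log(torch.tensor([0.1, 0.2, 0.7]))  # unnormalized scores for 3 classes
for temperature in (0.1, 1.0, 10.0):
    sample = gumbel_softmax(logits, temperature)
    print(f"lambda={temperature}: {sample}")
# Typical pattern (exact values vary per run): nearly one-hot at 0.1,
# close to uniform at 10.0.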
The Gumbel-Softmax trick can prove super useful in discrete sampling tasks, which used to be handled in other ways. For example, NLP tasks are almost necessarily discrete, such as the sampling of words, characters, or phonemes (Gumbel Softmax Loss Function Guide + How to Implement it in PyTorch).