The GFlowNets and Amortized Marginalization Tutorial
(see here for the more comprehensive GFlowNet tutorial; the material below may be useful to read first, as it provides a gist of the core ideas behind the ability of GFlowNets to learn to sample from and marginalize over distributions for which doing these tasks exactly is intractable)
Consider a set of intractable expectations that we would like to approximate (for any $x$)
$$S(x) = \sum_y p(y \mid x)\, f(x, y)$$
where $y$ is a rich variable taking an exponential number of values and we know how to sample from $p(y \mid x)$.
We could train a neural net with input $x$, stochastic target $f(x,y)$ and MSE loss
$$L(x, y) = \left(\hat{S}(x) - f(x, y)\right)^2$$
where $y \sim p(y \mid x)$. Training the estimator $\hat{S}$ (a neural network with its own parameters) with the stochastic gradients of $L$ would make $\hat{S}(x)$ converge to $S(x)$ if it has enough capacity and is trained long enough.
For any new $x$, we would then have an amortized estimator $\hat{S}(x)$ which in one pass through the network would give us an approximation of the intractable sum. We can consider this an efficient alternative to doing a Monte-Carlo approximation
$$S(x) \approx \frac{1}{N} \sum_{i=1}^{N} f(x, y_i), \qquad y_i \sim p(y \mid x),$$
which would require a potentially large number $N$ of samples and computations of $f$ for each $x$ at run-time, especially if $p(y \mid x)$ is a rich multimodal distribution (for which just a few samples do not give us a good estimator of the expectation).
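As a minimal illustration of both recipes, the PyTorch sketch below trains an amortized estimator of $S(x)$ by regressing on the stochastic target $f(x,y)$ and then compares a single forward pass with a Monte-Carlo average at run-time; the toy $p(y \mid x)$, $f$, and all names are illustrative stand-ins, not part of the original text.

```python
# Minimal sketch: amortized estimation of S(x) = E_{p(y|x)}[f(x, y)] vs. Monte Carlo.
# Toy assumptions: x is a 2D real vector, y is one of K discrete values sampled
# from a known softmax p(y|x), and f(x, y) is a cheap arbitrary function.
import torch
import torch.nn as nn

K = 32  # number of possible y values (tiny here; think "exponential" in practice)

def sample_y(x):                       # y ~ p(y|x), here a softmax over K values
    logits = torch.stack([x.sum(-1) * k / K for k in range(K)], dim=-1)
    return torch.distributions.Categorical(logits=logits).sample()

def f(x, y):                           # the quantity whose expectation we want
    return torch.cos(x.sum(-1) * (y + 1).float() / K)

amortized_S = nn.Sequential(nn.Linear(2, 64), nn.ReLU(), nn.Linear(64, 1))
opt = torch.optim.Adam(amortized_S.parameters(), lr=1e-3)

for step in range(5000):               # regress S_hat(x) onto the stochastic target f(x, y)
    x = torch.randn(128, 2)
    y = sample_y(x)
    target = f(x, y)
    loss = ((amortized_S(x).squeeze(-1) - target) ** 2).mean()
    opt.zero_grad(); loss.backward(); opt.step()

# At run-time: one forward pass vs. averaging many samples of f(x, y).
x_new = torch.randn(1, 2)
one_pass = amortized_S(x_new).item()
mc = f(x_new.expand(10000, 2), sample_y(x_new.expand(10000, 2))).mean().item()
print(one_pass, mc)                    # both approximate S(x_new)
```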
Besides the advantage of faster run-time, a crucial potential advantage of the amortized version is that it could benefit from generalizable structure in the product $p(y \mid x) f(x, y)$: if observing a training set of $(x, y)$ pairs can allow us to generalize to new pairs, then we may not need to train $\hat{S}$ with an exponential number of examples before it captures that generalizable structure and provides good answers (i.e., $\hat{S}(x) \approx S(x)$) on new $x$'s. The ability to generalize like this is actually what explains the remarkable success of machine learning (and in particular deep learning) in a vast set of modern AI applications.
When we do not have a $p(y \mid x)$ that we can sample from easily, we can in principle use MCMC methods, which form chains of samples of $y$ whose distributions converge to the desired $p(y \mid x)$, and where the next sample is generally obtained from the previous one by a small stochastic change that favours increases in $p(y \mid x)$. Unfortunately, when the modes of $p(y \mid x)$ occupy a small volume (i.e., throwing darts does not find them) and these modes are well-separated (by low-probability regions), especially in high dimension, it tends to take exponential time to mix from one mode to another. However, this is leaving money on the table: the attempts actually contain information one could use, with ML, to guess where the yet unseen modes might be, given where some of the modes (those already observed) are, as illustrated below.
The above figure illustrates that if we have already found the three modes shown in red, we may guess the presence of a fourth mode at the top-right grid point, because the first three modes seem to align on a grid. The existence of such generalizable structure is why amortized ML samplers can potentially do much better than MCMC samplers.
Let us consider the situation where we do not have a handy $p(y \mid x)$ and our objective is just to approximate a set of intractable sums (for any $x$)
$$S(x) = \sum_y R(x, y)$$
where we have the constraints that $R(x, y) \geq 0$ and that $R(x, y)$ can be computed for any given pair $(x, y)$. This may be useful to estimate a normalization constant for energy-based models or Bayesian posteriors (where $y$ = parameters and $x$ = data). Hence we may also be interested in the sampling policy
$$\pi(y \mid x) = \frac{R(x, y)}{S(x)} = \frac{R(x, y)}{\sum_{y'} R(x, y')}.$$
Now, the idea of GFN losses is based on the notion of a constraint that we would like to be true:
$$\hat{S}(x)\, \hat{\pi}(y \mid x) = R(x, y) \quad \forall x, y. \qquad (1)$$
We can define estimators $\hat{S}$ and $\hat{\pi}$ and train them with a loss such as
$$L(x, y) = \left(\hat{S}(x)\, \hat{\pi}(y \mid x) - R(x, y)\right)^2$$
or, with an interpretation of the $R(x, y)$ as unnormalized probabilities that we want well calibrated in the log-domain,
$$L(x, y) = \left(\log \hat{S}(x) + \log \hat{\pi}(y \mid x) - \log R(x, y)\right)^2$$
where the pairs $(x, y)$ are sampled from a training distribution that has full support. It can then be shown (main GFN theorem; see the NeurIPS paper and the GFN Foundations paper) that, with $\hat{\pi}(\cdot \mid x)$ constrained to be a normalized distribution, with enough capacity and trained for long enough, they both converge to their desired values:
$$\hat{S}(x) = \sum_y R(x, y), \qquad \hat{\pi}(y \mid x) = \frac{R(x, y)}{\sum_{y'} R(x, y')}.$$
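As a concrete illustration of this log-domain loss, here is a minimal PyTorch sketch for a toy discrete case where $y$ ranges over $K$ values and $R(x, y)$ is a cheap hand-picked positive reward; $\hat{\pi}$ is kept normalized with a softmax so that, at the optimum, $\hat{S}(x)$ recovers the intractable sum. All names and the toy reward are illustrative assumptions.

```python
# Minimal sketch of the log-domain squared loss enforcing S_hat(x) * pi_hat(y|x) = R(x, y).
# Toy assumptions: x is a 2D real vector, y ranges over K discrete values, and
# R(x, y) > 0 is a cheap, tractable reward; (x, y) pairs come from a full-support
# training distribution (here: x ~ N(0, I), y ~ uniform).
import torch
import torch.nn as nn

K = 32

def log_R(x, y):                               # tractable unnormalized log-reward
    return -((x.sum(-1) - y.float() / K) ** 2)

net = nn.Sequential(nn.Linear(2, 64), nn.ReLU(), nn.Linear(64, K + 1))
opt = torch.optim.Adam(net.parameters(), lr=1e-3)

for step in range(10000):
    x = torch.randn(128, 2)
    y = torch.randint(0, K, (128,))            # full-support sampling of y
    out = net(x)
    log_S = out[:, 0]                          # log S_hat(x)
    log_pi = torch.log_softmax(out[:, 1:], -1) # log pi_hat(.|x), normalized over y
    log_pi_y = log_pi.gather(1, y[:, None]).squeeze(1)
    loss = ((log_S + log_pi_y - log_R(x, y)) ** 2).mean()
    opt.zero_grad(); loss.backward(); opt.step()

# At the optimum: S_hat(x) ~ sum_y R(x, y) and pi_hat(y|x) ~ R(x, y) / S_hat(x).
x_new = torch.randn(5, 2)
with torch.no_grad():
    exact = torch.logsumexp(
        torch.stack([log_R(x_new, torch.full((5,), k)) for k in range(K)], -1), -1)
    print(net(x_new)[:, 0], exact)             # learned log S_hat vs. exact log-sum
```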
To make the notion of intractable sum more concrete, it is good to think of $y$ (and potentially $x$ as well) as a compositional object, like a subset of variable-value pairs or a graph. Concretely, we think of compositional objects as being constructed sequentially through a series of steps where a new piece is inserted at each step. The sampling policy then samples these steps (which we call actions) and after each step we get a partially constructed object which we call a GFN state, denoted $s$. The sequence of such states and actions forms a GFN trajectory. In the basic GFN framework, these transitions are deterministic (i.e., a state and an action determine exactly the next state) because they are not happening in an external environment but are part of the internal computation of a sampler. In addition, the GFN math as it currently stands (and it can probably be extended) assumes that each step is constructive, i.e., we cannot return to the same partially constructed object twice, which means that the set of all possible trajectories forms a directed acyclic graph (DAG). Because the transitions are deterministic, we can specify a trajectory with a sequence of states (we could have equally chosen an initial state and a sequence of actions). A special action also needs to be defined to declare that a state $s$ can be output as a fully constructed object $y$. The policy is now specified by a forward transition distribution $P_F(s_{t+1} \mid s_t)$ which specifies how to generate each constructive step, and we are interested in parametrizing and learning this $P_F$.
Besides the sequential nature of the generative process for $y$, an interesting complication is that there may be many ways (in fact exponentially many trajectories) to construct $y$ from some starting point and context $x$. This means that a partially constructed object, i.e., a state $s'$, may have multiple parents $s$ for which an action exists that leads to $s'$. Otherwise, the DAG is a tree, which makes the math and calculations much simpler. But when it is not a tree, it turns out to be convenient to consider and parametrize a backward transition probability function $P_B(s \mid s')$ which is consistent with that DAG and the associated $P_F$. The constraint in Eq. (1) can be reformulated in several ways, in particular what is called the detailed balance constraint:
$$F(s)\, P_F(s' \mid s) = F(s')\, P_B(s \mid s')$$
where $F(s)$ is called the flow at state $s$ and plays a role similar to $\hat{S}(x)$ above, i.e., it is an intractable sum, and there is a starting state $s_0$ from which a trajectory is initiated, as well as a constraint that the flow into a terminal state corresponding to a complete object $y$ equals $R(x, y)$. Similarly to the simpler case above, this can be turned into a training loss that we want to minimize, but now over all $(s, s')$ pairs or over all (state, action) pairs. We also note that the initial flow equals the sum of all the rewards, i.e., the initial flow provides us with the normalizing constant:
$$F(s_0) = \sum_y R(x, y) = S(x).$$
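Below is a minimal sketch of a detailed-balance-style training loop on a toy tree-structured DAG (binary strings built one bit at a time), where each state has a unique parent so $P_B = 1$; this is the simple tree case mentioned above, chosen so that the learned initial flow can be checked against the exact normalizing constant. The environment, reward, and names are illustrative assumptions, not the original's setup.

```python
# Minimal sketch of a detailed-balance-style loss on a toy tree-structured DAG:
# states are binary prefixes, each step appends one bit, and complete strings of
# length L are terminal with reward R(x) > 0.  Every state has a single parent,
# so P_B(s|s') = 1 and the DAG is a tree (the simple case mentioned in the text).
import math
import torch
import torch.nn as nn

L = 6

def encode(bits):                                  # fixed-size encoding of a prefix
    v = torch.full((L,), -1.0)
    v[:len(bits)] = torch.tensor(bits, dtype=torch.float)
    return v

def log_R(bits):                                   # reward on complete strings only
    return torch.tensor(float(sum(bits)))          # R(x) = exp(number of ones)

net = nn.Sequential(nn.Linear(L, 64), nn.ReLU(), nn.Linear(64, 3))  # [log F(s), logits of P_F(.|s)]
opt = torch.optim.Adam(net.parameters(), lr=1e-3)

for step in range(5000):
    # roll out one trajectory s_0 -> s_1 -> ... -> x with the current forward policy
    bits, loss = [], 0.0
    for t in range(L):
        out = net(encode(bits))
        log_F_s, logits = out[0], out[1:]
        log_PF = torch.log_softmax(logits, -1)
        a = torch.distributions.Categorical(logits=logits.detach()).sample().item()
        next_bits = bits + [a]
        if len(next_bits) == L:                    # child is terminal: its flow is R(x)
            log_F_next = log_R(next_bits)
        else:
            log_F_next = net(encode(next_bits))[0]
        # detailed balance in the log-domain, with log P_B = 0 (unique parent)
        loss = loss + (log_F_s + log_PF[a] - log_F_next) ** 2
        bits = next_bits
    opt.zero_grad(); loss.backward(); opt.step()

# The initial flow estimates the normalizing constant: F(s_0) ~ sum_x R(x).
with torch.no_grad():
    print(net(encode([]))[0].exp().item(), (1 + math.e) ** L)  # sum_x e^{#ones} = (1+e)^L
```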
Let us consider the special case where $y = \theta$ is a latent parameter vector, i.e., in a Bayesian setting, and $x = D$ is the available data. Then we can define the reward function
$$R(\theta, D) = p(\theta)\, p(D \mid \theta)$$
from the parameter prior $p(\theta)$ and the data likelihood $p(D \mid \theta)$. Training a GFN with a policy $\hat{\pi}(\theta \mid D)$ and initial flow $\hat{S}(D)$ provides us with an approximate sampler for the posterior over parameters given data, $p(\theta \mid D)$, as well as an estimator $\hat{S}(D) \approx p(D) = \sum_\theta p(\theta)\, p(D \mid \theta)$ of the normalizing constant of the Bayesian posterior. Hence, we have used amortization to turn a tractable quantity (prior times likelihood) into estimators of these intractable quantities. We get a fast sampler for the posterior with no need for a Markov chain going through a large number of candidate samples, and from it we can generate many samples in an iid fashion that are likely to visit the larger modes of the posterior.
To make computations more ML-friendly while training the GFN, we can note that the GFN squared-loss objectives naturally lend themselves to the case where the reward or log-reward is stochastic, as long as it is an unbiased estimator of the true reward or log-reward. For example, we can typically decompose the overall dataset log-likelihood into a sum of per-example or per-minibatch terms, and we can introduce a multiplicative correction to account for the prior $p(\theta)$:
$$\log R(\theta, D) = \log p(\theta) + \sum_{i=1}^{N} \log p(d_i \mid \theta) = \sum_{i=1}^{N} \left[ \log p(d_i \mid \theta) + \frac{1}{N} \log p(\theta) \right]$$
where $D = \{d_1, \ldots, d_N\}$, so that an unbiased estimate of $\log R(\theta, D)$ can be obtained from a random minibatch by rescaling its sum.
This makes it possible to train the GFN posterior estimator using SGD on single examples or minibatches, which is the state of the art for training deep nets.
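For instance, a sketch of such an unbiased minibatch estimator of the log-reward could look like the following, where the Gaussian prior and likelihood are illustrative stand-ins for whatever model is being used.

```python
# Minimal sketch of an unbiased minibatch estimator of the log-reward
#   log R(theta, D) = sum_i [ log p(d_i | theta) + (1/N) log p(theta) ],
# obtained by rescaling the sum over a uniformly sampled minibatch by N/m.
# The Gaussian prior/likelihood below are illustrative stand-ins.
import torch
from torch.distributions import Normal

def log_prior(theta):                 # e.g. theta ~ N(0, 1)
    return Normal(0.0, 1.0).log_prob(theta)

def log_lik(theta, d):                # e.g. d | theta ~ N(theta, 1)
    return Normal(theta, 1.0).log_prob(d)

def minibatch_log_reward(theta, minibatch, N):
    """Unbiased estimate of log R(theta, D) from a uniformly sampled minibatch."""
    m = len(minibatch)
    per_example = torch.stack([log_lik(theta, d) + log_prior(theta) / N for d in minibatch])
    return (N / m) * per_example.sum()

# Usage: the estimator's expectation over random minibatches equals the full log R.
D = torch.randn(1000) + 0.7           # toy dataset of N = 1000 scalar observations
theta = torch.tensor(0.5)
idx = torch.randperm(len(D))[:32]     # random minibatch of size 32
print(minibatch_log_reward(theta, D[idx], len(D)))
```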
In addition to training a GFN to represent the parameter posterior $p(\theta \mid D)$, one could train it to estimate jointly the predictive posterior $P(y \mid x, D)$, where $D$ is formed of $(x, y)$ pairs, $x$ contains the $x$-part of a new example and $y$ its $y$-part. This could be achieved with the same reward but with a generative policy that first samples $y$ given $(x, D)$ via $\hat{P}(y \mid x, D)$ and then samples the parameters from $\hat{P}(\theta \mid x, y, D)$. In the case of exchangeable examples, the pairs in $D$ are iid given $\theta$, so that
$$R(\theta, y, (x, D)) = p(\theta)\, p(D \mid \theta)\, p(y \mid x, \theta)$$
where $p(D \mid \theta) = \prod_i p(y_i \mid x_i, \theta)$, and we are training (with the factors above) a predictive posterior $\hat{P}(y \mid x, D)$ that can predict the output $y$ for a new experiment $x$ given all the past experimental results $D$.
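One way to see why the two-step policy above yields the predictive posterior is to marginalize the reward over $\theta$, which gives (up to the constant $p(D)$) exactly $P(y \mid x, D)$:
$$\sum_\theta R(\theta, y, (x, D)) = \sum_\theta p(\theta)\, p(D \mid \theta)\, p(y \mid x, \theta) = p(D) \sum_\theta p(\theta \mid D)\, p(y \mid x, \theta) = p(D)\, P(y \mid x, D),$$
so a GFN whose first sampling step produces $y$ will, at convergence, have a first-step marginal proportional to this sum, i.e., $\hat{P}(y \mid x, D) \approx P(y \mid x, D)$.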
A causal model can be seen as a family of distributions indexed by interventions, where all the distributions in the family share the same parameters specifying causal dependencies between random variables (but differ on which variables were intervened and what values were imposed on these variables). One of the main challenges in learning a causal model is discovering the causal graph, a directed acyclic graph (DAG) with one node per random variable and a directed edge between a random variable and each of its direct causes. One problem is that the number of such DAGs is super-exponential in the number of variables. Another is identifiability: if we only observe data from one member of that family (typically the default of no intervention), then in general there are many DAGs that are compatible with the data (they are collectively called the Markov Equivalence Class or MEC), even as the amount of data goes to infinity. On the other hand, we may have the opportunity to observe the outcomes of many interventions (also called experimental results), and then the ambiguity about the correct causal graph should decrease. The MEC is interesting but does not cover the case where the amount of data is finite and/or we observe different experiments. A more general description of the ambiguity regarding the causal graph is given by the Bayesian posterior on the causal graph, given the available information (data and possibly other sources of information which may constrain the causal graph). However, this is a highly multimodal distribution in a discrete space, which makes it difficult to represent and to learn. GFNs have been used to represent and learn that distribution, taking advantage of the simple ideas in the previous section and the abilities of GFNs to learn distributions over graphs: see Deleu et al., UAI'2022.
An appealing theoretical framework for defining the objective of an experiment is that of information gain: how much information about a random variable of interest can we expect to gain through the experiment? Experimental design could be driven by this information gain as a reward. Note that we can replace the word "experiment" by "action" to start thinking about exploration and information seeking in reinforcement learning, and such a framework also applies in the context of active learning or interactive learning more generally. In general, information gain may not be the only effect of an action, though. For example there could be risks involved. We can however incorporate a simple notion of cost or budget by considering the information gain per unit of cost incurred.
Information gain can be measured in principle by the mutual information between the outcome of an experiment (a random variable since the experiment has not taken place yet) and the variable of interest (about which we seek to gain information), given the experimental specification and any other knowledge (including data) we may already have. In the simplest and purely unsupervised knowledge-seeking scenario, the variable of interest may be the causal model explaining the outcomes of experiments. In a more targeted scenario, for example in drug discovery, it is the set of molecules that have a particular property of interest (e.g., affinity with a target protein is above a threshold and toxicity is below a threshold and synthesis cost is below a threshold). Let us write $U$ for the experimental outcome, $a$ for the experiment specification, and $V$ for the variable of interest about which we seek to gain information. Then the reward function for our experimental design could be
$$R(a) = MI(V; U \mid a, D) = E\!\left[\log \frac{P(V \mid U, a, D)}{P(V \mid a, D)}\right] = E\!\left[\log \frac{P(U \mid V, a, D)}{P(U \mid a, D)}\right]$$
where the expectations are over $(U, V) \sim P(U, V \mid a, D)$, and $D$ is our dataset of already acquired experimental results and any other constraint we want to exploit to condition the probabilities. With $V$ typically being a much higher-dimensional object than $U$, the second formulation on the right tends to be more practical numerically. We can then use the ideas about amortization and GFNs to estimate (a) the numerator and denominator probabilities, (b) the sampler for the joint $P(U, V \mid a, D)$, and (c) an estimator of the MI itself, as a function of $a$. In the case where $V = \theta$ (the model parameters), we can follow the approach outlined in section 4 to estimate the predictive posterior $P(U \mid a, D)$, while $P(U \mid V, a, D) = p(U \mid a, \theta)$ is given by the generative likelihood function. That approach also provides a sampler for the joint by combining the generative likelihood $p(U \mid a, \theta)$ and the parameter posterior sampler $\hat{P}(\theta \mid D)$. As for the estimator of the MI itself, one possibility is to train a neural network $\widehat{MI}(a)$ to amortize the expected value over $(U, V)$ given $(a, D)$, using the samples from (b) and the estimators for the probabilities in the log-prob ratio from (a) and a squared loss
$$L(a, U, V) = \left(\widehat{MI}(a) - \log \frac{\hat{P}(U \mid V, a, D)}{\hat{P}(U \mid a, D)}\right)^2$$
where $a$ is sampled from a dataset or a generative model of inputs, and $(U, V) \sim \hat{P}(U, V \mid a, D)$. With enough capacity and training time, $\widehat{MI}(a)$ will converge to $MI(V; U \mid a, D)$. What is especially interesting, if $a$ lives in a high-dimensional space, is that it generally won't be necessary to see more than one value of $U$ and $V$ for each value of $a$ in order to train $\widehat{MI}$, as usual in supervised learning. This can work if it is possible for $\widehat{MI}$ to generalize from the $(a, U, V)$ triplets used to train it.
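As a minimal self-contained illustration of this regression, the sketch below uses a toy linear-Gaussian model ($V \sim N(0,1)$, $U \mid V, a \sim N(aV, 1)$) in place of the GFN-based estimators from (a) and (b), so the learned $\widehat{MI}(a)$ can be compared with the exact value $\tfrac{1}{2}\log(1 + a^2)$; the model and all names are illustrative assumptions.

```python
# Minimal sketch of amortizing the information gain MI(V; U | a):
# regress MI_hat(a) onto the stochastic target log p(U | V, a) - log p(U | a),
# with (U, V) sampled from the joint given a.  A toy linear-Gaussian model
# stands in for the GFN-based estimators of steps (a) and (b), so the result
# can be checked against the exact MI, 0.5 * log(1 + a^2).
import torch
import torch.nn as nn
from torch.distributions import Normal

mi_hat = nn.Sequential(nn.Linear(1, 64), nn.ReLU(), nn.Linear(64, 1))
opt = torch.optim.Adam(mi_hat.parameters(), lr=1e-3)

for step in range(20000):
    a = 3 * torch.rand(256, 1)                      # experiment specification
    V = torch.randn(256, 1)                         # variable of interest (toy "posterior" sample)
    U = a * V + torch.randn(256, 1)                 # experimental outcome ~ p(U | V, a)
    log_num = Normal(a * V, 1.0).log_prob(U)        # log p(U | V, a)
    log_den = Normal(0.0, (1 + a ** 2).sqrt()).log_prob(U)   # log p(U | a)
    target = (log_num - log_den).squeeze(-1)        # one (U, V) sample per value of a
    loss = ((mi_hat(a).squeeze(-1) - target) ** 2).mean()
    opt.zero_grad(); loss.backward(); opt.step()

with torch.no_grad():
    a_test = torch.tensor([[0.5], [1.0], [2.0]])
    print(mi_hat(a_test).squeeze(-1))                     # learned MI_hat(a)
    print(0.5 * torch.log(1 + a_test.squeeze(-1) ** 2))   # exact MI for the toy model
```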