Amortized and guided inference roadmap
This is a set of notes to plan the way forward for the various variational / neurally guided / amortized inference efforts in WebPPL.
In general, we augment the target probabilistic program with a side computation (the 'guide') that decides which distribution to actually sample from at each sample statement. A perfect guide would produce samples from the posterior, accounting for factors, and would thus be a perfect importance distribution. The goal is therefore to learn a guide that minimizes the distance between the guided distribution and the target model's posterior. We generally use guide families with continuous parameters, giving a continuous optimization problem. There are various ways to formalize (or approximate) this objective and to optimize it.
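Concretely (notation mine): writing the distribution over traces induced by the guided program as q_phi(x) and the target posterior as p(x | data), the learning problem is roughly

```latex
\phi^{*} \;=\; \arg\min_{\phi}\; D\big(q_{\phi}(x),\; p(x \mid \mathrm{data})\big)
```

where D is some divergence between the guided distribution and the posterior; the flavors below mostly differ in which direction of KL is used for D and in how its gradient is estimated.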
We've been playing with several different flavors of objective / optimization.
In the first flavor, the objective is the ELBo (a.k.a. the negative variational free energy): minimize KL(guided, target). The expectation in this KL is taken under the guide, so the gradient is a bit subtle. One approach is to derive a form for the gradient given samples from the guide, which ends up as the log-likelihood ratio guided/target times the gradient of the guide's log probability (call this the LR approach); the other is to re-parametrize guided random choices, when possible, so that the guide looks like fixed parameter-free random draws followed by parameter-dependent arithmetic (call this the reparam approach).
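For concreteness, here's a toy numerical sketch of the two estimators, outside WebPPL, on a tiny Gaussian model with a Gaussian guide (the model, the parameterization, and all names are mine, purely for illustration):

```python
# Toy model:  x ~ N(0,1),  y ~ N(x, 0.5),  observe y = 2.
# Guide: q(x) = N(mu, sigma), parametrized by (mu, log_sigma).
import numpy as np

y_obs, obs_sd = 2.0, 0.5

def log_p(x):                       # unnormalized target: log prior + log likelihood
    return -0.5 * x**2 - 0.5 * ((y_obs - x) / obs_sd) ** 2

def dlogp_dx(x):                    # derivative of log_p w.r.t. x
    return -x + (y_obs - x) / obs_sd**2

def lr_grad(mu, log_sigma, n=1000):
    """Likelihood-ratio (score-function) estimator of the ELBo gradient."""
    sigma = np.exp(log_sigma)
    x = np.random.normal(mu, sigma, n)
    log_q = -np.log(sigma) - 0.5 * ((x - mu) / sigma) ** 2 - 0.5 * np.log(2 * np.pi)
    w = log_p(x) - log_q                          # the 'reward' multiplying the score
    dlogq_dmu = (x - mu) / sigma**2
    dlogq_dlogsigma = ((x - mu) / sigma) ** 2 - 1.0
    return np.mean(w * dlogq_dmu), np.mean(w * dlogq_dlogsigma)

def reparam_grad(mu, log_sigma, n=1000):
    """Re-parametrized estimator: x = mu + sigma * eps with eps ~ N(0,1)."""
    sigma = np.exp(log_sigma)
    eps = np.random.normal(size=n)
    x = mu + sigma * eps
    # d/dmu E[log_p] = E[dlogp_dx]; d/dlog_sigma E[log_p] = E[dlogp_dx * sigma * eps];
    # the guide's entropy contributes 0 to d/dmu and +1 to d/dlog_sigma.
    return np.mean(dlogp_dx(x)), np.mean(dlogp_dx(x) * sigma * eps) + 1.0

print(lr_grad(0.0, 0.0))        # same expectation...
print(reparam_grad(0.0, 0.0))   # ...but typically much lower variance
```

Both estimators agree in expectation; the reparam version exploits the model's gradient rather than only its score, which is where the variance win comes from.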
The connection between the LR approach and reinforcement learning is useful: the log-likelihood ratio plays the role of a reward and the guide acts as a policy. See for example Reinforced Variational Inference.
The two can be combined, using reparam for the guided choices where it can be done. However, the gradient only propagates backward until it hits a discrete choice, so this is most useful for models with discrete choices early in the generative process and continuous choices later. (This is the opposite of, e.g., LDA.)
The big issue is the variance of these gradient estimators. Re-parametrization is reported to have much lower variance than LR.
An estimate of the baseline 'value' can be subtracted from the 'reward' to reduce variance. We should try learning a neural net for the baseline, as done in the recent DeepMind paper. (We should also try parallel threaded updates.)
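A minimal sketch of baseline subtraction (not WebPPL; a scalar running-average baseline is shown, whereas a learned neural baseline would predict this value from trace context; the toy target and all names are assumptions):

```python
import numpy as np

def lr_step(mu, sigma, log_p, baseline, lr=0.01, n=100):
    """One LR gradient step on mu with a control variate subtracted from the reward."""
    x = np.random.normal(mu, sigma, n)
    log_q = -np.log(sigma) - 0.5 * ((x - mu) / sigma) ** 2 - 0.5 * np.log(2 * np.pi)
    reward = log_p(x) - log_q
    # Subtracting a constant leaves the estimator unbiased (E_q[dlogq/dmu] = 0)
    # but can shrink its variance considerably.
    grad_mu = np.mean((reward - baseline) * (x - mu) / sigma**2)
    baseline = 0.9 * baseline + 0.1 * np.mean(reward)    # running-average update
    return mu + lr * grad_mu, baseline

mu, baseline = 0.0, 0.0
for _ in range(2000):
    mu, baseline = lr_step(mu, 1.0, lambda x: -0.5 * (x - 1.6) ** 2, baseline)
print(mu)   # should drift toward the toy target's mean, 1.6
```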
In cases where there is a discrete choice late in the generative process, which prevents gradients from propagating back to the earlier continuous choices, it is sometimes practical to simply enumerate the traces for all values of this choice. We should make this easier to do (e.g. an Enumerate directive or ERP argument?).
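One possible reading, as a sketch (illustrative only; names are mine): instead of sampling the late discrete choice, take the exact expectation over its values, so it contributes no estimator variance and gradients for the earlier continuous choices flow through every enumerated branch.

```python
def expected_over_discrete(q_probs, per_value_objective):
    """Exact expectation over one discrete choice k with probabilities q_probs;
    per_value_objective(k) is the rest of the (reparametrized) objective with k fixed."""
    return sum(q * per_value_objective(k) for k, q in enumerate(q_probs))
```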
If the Markov blanket of a random choice is known, variance can be reduced by 'ignoring' variables outside the Markov blanket of each target variable. (Does this generalize to neural guides that introduce long-range dependence?)
Sometimes discrete ERPs can be replaced by continuous ones (e.g. discrete(probs) becomes dirichlet(alpha*probs)). This lets us use the reparam method instead of the LR method, reducing variance. The main obstacle is when the discrete variable is needed by a primitive that can't handle real-valued input.
Can we relax or enumerate all discrete vars, and then add discrete randomization only at conditionals? That is, we should only have to incur the variance hit for structure changes.
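A toy sketch of the discrete(probs) -> dirichlet(alpha*probs) relaxation from above (illustrative Python, not WebPPL; alpha and the mixture usage are assumptions): downstream code gets soft weights on the simplex instead of a hard index and has to be written to mix.

```python
import numpy as np

rng = np.random.default_rng(0)

def relaxed_choice(probs, alpha=10.0):
    # Dirichlet draw with mean `probs`; larger alpha concentrates it near probs,
    # smaller alpha pushes it toward the simplex corners (more discrete-like).
    return rng.dirichlet(alpha * np.asarray(probs))

component_means = np.array([-2.0, 0.0, 3.0])
w = relaxed_choice([0.7, 0.2, 0.1])
soft_value = w @ component_means   # where a hard k would have indexed component_means[k]
```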
There are a lot of approaches to variance reduction out there, and we'd probably do well to evaluate the options. Here's a list of some papers worth looking at for ideas:
- Local Expectation Gradients for Doubly Stochastic Variational Inference
- MuProp
- Deterministic policy gradients
A second flavor optimizes the (possibly more sensible) 'evidence upper bound' (EUBo): KL(target, guided). The expectation in this KL requires posterior samples; on the bright side, once we have those training traces we simply optimize the probability of the guide making the training choices (i.e. it becomes simple maximum likelihood).
This is a kind of apprenticeship training where the training examples come from a known-to-converge Monte Carlo algorithm for the target model. For example, Ritchie et al. use a particle filter with enough particles to get good posterior samples. We could also use MCMC.
By using the guide as part of the teacher algorithm (e.g. a particle filter with the guide as the importance distribution) there is the potential for bootstrapping -- this needs to be evaluated.
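A minimal sketch of the maximum-likelihood step (not WebPPL; the single Gaussian guide, step sizes, and the stand-in 'posterior samples' are all assumptions):

```python
import numpy as np

def fit_gaussian_guide(trace_values, steps=500, lr=0.05):
    """Fit guide params (mu, sigma) by gradient ascent on the average
    log-likelihood of the recorded training choices."""
    mu, log_sigma = 0.0, 0.0
    for _ in range(steps):
        sigma = np.exp(log_sigma)
        g_mu = np.mean((trace_values - mu) / sigma**2)
        g_ls = np.mean(((trace_values - mu) / sigma) ** 2 - 1.0)
        mu += lr * g_mu
        log_sigma += lr * g_ls
    return mu, np.exp(log_sigma)

# Stand-in for traces produced by a particle filter / MCMC run on the target.
posterior_samples = np.random.normal(1.6, 0.45, size=2000)
print(fit_gaussian_guide(posterior_samples))   # roughly (1.6, 0.45)
```

With real training traces the same loop runs over every guided choice in each trace; for simple guide families the closed-form MLE could be used directly.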
For the guide family itself, the simplest choice is mean-field. More interesting are various neural families that pass information around in order to help approximate posterior dependence (dependence that isn't in the prior).
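To make the contrast concrete, a tiny sketch (names, shapes, and the one-layer 'network' are all assumptions): mean-field keeps an independent parameter vector per choice, while a neural guide computes a choice's parameters from context with weights shared across choices, which is what lets it capture posterior dependence.

```python
import numpy as np

# Mean-field: independent (mu, log_sigma) per latent choice, no coupling.
mean_field_params = {"x1": (0.0, 0.0), "x2": (0.0, 0.0)}

# Neural guide: shared weights map context (observations, earlier choices)
# to the current choice's parameters.
def neural_guide_params(context, W, b):
    h = np.tanh(W @ context + b)
    return h[0], h[1]                        # (mu, log_sigma)

W, b = np.zeros((2, 3)), np.zeros(2)         # shared across all choices
mu, log_sigma = neural_guide_params(np.array([1.0, -0.5, 2.0]), W, b)
```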
LDA is a tricky example because the discrete topic-per-word variables happen late in the generative process, and there are a lot of them, resulting in very high variance for the LR method. Rao-Blackwellization helps, but isn't very feasible for generic inference. Some things to try:
- Marginalize out the topic-per-word choice by enumerating inside the `generateWord(docTopicDist, topicWordDists)` function (see the sketch after this list). If the number of topics is small, this should be reasonably efficient and allow us to use the reparam method for the Dirichlet samples.
- Use a neural guide family that shares params across documents and words. For example, for sampling the docTopicDist, use a guide NN that takes the word-count vector as input; for the topicPerWord, use a guide that takes the topicDist and the word as input (or enumerate it out as above). By reducing the number of params and sharing across documents/words, this may converge faster even with the noisy estimator.
- Relax the discrete topicPerWord choice into a continuous mixtureOfTopicsPerWord vector?
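A sketch of the enumeration in the first bullet (illustrative Python rather than WebPPL; the concrete numbers are made up): the topic-per-word choice is summed out inside the word likelihood, so only the Dirichlet draws remain to be guided and the reparam method applies.

```python
import numpy as np
from scipy.special import logsumexp

def log_prob_word(word, docTopicDist, topicWordDists):
    """log p(word | docTopicDist, topicWordDists) with the topic summed out."""
    K = len(docTopicDist)
    return logsumexp([np.log(docTopicDist[k]) + np.log(topicWordDists[k][word])
                      for k in range(K)])

docTopicDist = np.array([0.5, 0.3, 0.2])
topicWordDists = np.array([[0.6, 0.3, 0.1],
                           [0.1, 0.8, 0.1],
                           [0.2, 0.2, 0.6]])
print(log_prob_word(1, docTopicDist, topicWordDists))   # log(0.5*0.3 + 0.3*0.8 + 0.2*0.2)
```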