Visualizing VI and SGHMC on a Multimodal Distribution
In this example, we'll compare vi.diag and sgmcmc.sghmc for approximating a two-dimensional distribution which we can visualize.
Target distribution
We'll start by defining the target distribution, a two-dimensional double well:
import torch
import matplotlib.pyplot as plt
import matplotlib.lines as mlines
import torchopt
import posteriors

torch.manual_seed(42)


def log_posterior(x, batch):
    log_prob = -torch.sum(x**4, axis=-1) / 10.0 + torch.sum(x**2, axis=-1)
    return log_prob, torch.tensor([])
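For reference, this implements the unnormalized log density

log p(x) ∝ (x₁² − x₁⁴ / 10) + (x₂² − x₂⁴ / 10),

whose stationary points satisfy 2xᵢ − 2xᵢ³/5 = 0, giving modes at xᵢ = ±√5 ≈ ±2.24 in each coordinate and hence four well-separated modes in two dimensions.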
Note that the log_posterior has to conform to the signature log_posterior(params, batch) -> log_prob, aux, where params holds the parameters (here a single tensor of shape (num_params,)) and batch contains the data, typically a dictionary or tuple of tensors. More info can be found on the constructing log posteriors page. In this simple example we don't have varying batches, so we'll just ignore that input. We also don't have any auxiliary information we'd like to keep hold of, so we just return an empty tensor.
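For illustration only, a sketch of a hypothetical log posterior that does use the batch and the auxiliary output might look like the following (the Gaussian model and the batch["y"] key are made up for this sketch and aren't used anywhere else in this example):

def gaussian_log_posterior(params, batch):
    # Hypothetical model: standard normal prior on the parameters
    log_prior = -0.5 * torch.sum(params**2)
    # Gaussian likelihood for a minibatch of observations batch["y"],
    # assumed shape (batch_size, num_params)
    per_datum_log_lik = -0.5 * torch.sum((batch["y"] - params) ** 2, axis=-1)
    # Return the joint log density and the per-datapoint log-likelihoods as aux
    return log_prior + per_datum_log_lik.sum(), per_datum_log_lik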
Variational Inference
Now we'll fit a diagonal Gaussian variational distribution. This is easy with posteriors:
# Build the diagonal Gaussian VI transform
vi_transform = posteriors.vi.diag.build(
    log_posterior, optimizer=torchopt.adam(lr=1e-2), init_log_sds=-2.0
)

n_vi_steps = 2000
vi_state = vi_transform.init(torch.zeros(2))

# Run the VI optimization, tracking the negative ELBO at each step
nelbos = []
for _ in range(n_vi_steps):
    vi_state = vi_transform.update(vi_state, None)
    nelbos.append(vi_state.nelbo.item())
Here, we've tracked the values of the negative ELBO (NELBO); let's have a look at them:
plt.plot(nelbos)
plt.ylabel("NELBO")
Looks like it converged, but there's a fair amount of variance around the minimum; maybe a Gaussian isn't a great fit for our target distribution...
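Before moving on, we can also peek at the fitted variational parameters directly. This is just a quick sketch; the mean and log standard deviations live in the state as params and log_sd_diag, the same attributes used in the plotting code further down:

vi_mean = vi_state.params
vi_sd = torch.exp(vi_state.log_sd_diag)
print("VI mean:", vi_mean)
print("VI standard deviations:", vi_sd)

# Draw a few samples from the fitted diagonal Gaussian
vi_samples = vi_mean + vi_sd * torch.randn(1000, 2)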
SGHMC
Let's generate samples with posteriors.sgmcmc.sghmc instead:
# Build the SGHMC transform
sghmc_transform = posteriors.sgmcmc.sghmc.build(log_posterior, lr=5e-2, alpha=1.0)

n_sghmc_steps = 10000
sghmc_state = sghmc_transform.init(torch.zeros(2))

# Run the sampler, storing every iterate and its log posterior value
samples = torch.zeros(1, 2)
log_posts = []
for _ in range(n_sghmc_steps):
    sghmc_state = sghmc_transform.update(sghmc_state, None)
    samples = torch.cat([samples, sghmc_state.params.unsqueeze(0)], axis=0)
    log_posts.append(sghmc_state.log_posterior.item())
Here we've tracked the values of the log posterior; let's have a look at them:
plt.plot(log_posts)
plt.ylabel("SGHMC Log Posterior")
Certainly some exploration going on! We can see that the log posterior evaluations vary quite a lot during sampling but for the most part stay significantly higher than the initialization value. Hopefully this means we're exploring the modes of the distribution, but we'll find out for sure when we visualize the samples.
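As a quick sanity check (a small sketch, not part of the original script), we can drop the all-zeros placeholder row that samples was initialized with and look at simple summaries of the chain:

sghmc_samples = samples[1:]  # discard the initial zeros placeholder row
print("Sample mean:", sghmc_samples.mean(dim=0))
print("Sample standard deviation:", sghmc_samples.std(dim=0))

A chain that visits all four wells roughly evenly should have a mean near zero and per-coordinate standard deviations on the order of the mode locations (±√5).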
Visualizing
Time to visualize the learnt variational distribution and the samples generated by SGHMC:
Code to plot the distributions
lim = 4
x = torch.linspace(-lim, lim, 1000)
X, Y = torch.meshgrid(x, x)

# Evaluate the target log density on the grid
Z = torch.vmap(log_posterior, in_dims=(0, None))(torch.stack([X, Y], axis=-1), None)[0]
plt.contourf(X, Y, Z, levels=50, cmap="Purples", alpha=0.3, zorder=-1)

# Evaluate the fitted diagonal Gaussian log density (up to a constant)
mean = vi_state.params
sd_diag = torch.exp(vi_state.log_sd_diag)
Z_gauss = torch.vmap(
    lambda z: -torch.sum(torch.square((z - mean) / sd_diag), axis=-1) / 2.0,
)(torch.stack([X, Y], axis=-1))
plt.contour(X, Y, Z_gauss, levels=5, colors="black", alpha=0.5)

# Overlay the SGHMC samples
sghmc_samps = plt.scatter(
    samples[:, 0], samples[:, 1], c="r", s=0.5, alpha=0.5, label="SGHMC Samples"
)

vi_legend_line = mlines.Line2D(
    [], [], color="black", label="VI", alpha=0.5, linestyle="--"
)
plt.legend(handles=[vi_legend_line, sghmc_samps])
plt.xlim(-lim, lim)
plt.ylim(-lim, lim)
We can see the variational Gaussian ignores the multiple modes, but SGHMC explores them well.
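As a rough quantitative companion to the plot (a sketch building on the quantities already computed above, not part of the original script), we can compare the fraction of SGHMC samples falling in each quadrant, where the four modes sit, with the probability mass the fitted diagonal Gaussian assigns to that quadrant:

from torch.distributions import Normal

vi_marginals = Normal(mean, sd_diag)  # per-coordinate marginals of the VI fit
p_positive = 1 - vi_marginals.cdf(torch.zeros(2))  # P(x_i > 0) under the VI fit

chain = samples[1:]  # drop the initial zeros placeholder row, as above
for signs in torch.tensor([[1.0, 1.0], [1.0, -1.0], [-1.0, 1.0], [-1.0, -1.0]]):
    # Fraction of SGHMC samples whose coordinates match this sign pattern
    sghmc_frac = ((chain * signs) > 0).all(dim=-1).float().mean().item()
    # Mass the diagonal Gaussian places in the same quadrant
    vi_mass = torch.where(signs > 0, p_positive, 1 - p_positive).prod().item()
    print(f"Quadrant {signs.tolist()}: SGHMC fraction {sghmc_frac:.2f}, VI mass {vi_mass:.2f}")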
Note
The raw code for this example can be found in the repo at examples/double_well.py.