Visualizing VI and SGHMC on a Multimodal Distribution

In this example, we'll compare vi.diag and sgmcmc.sghmc for approximating a two-dimensional distribution that we can visualize.

Target distribution

We'll start by defining the target distribution, a two-dimensional double well:

import torch
import matplotlib.pyplot as plt
import matplotlib.lines as mlines
import torchopt
import posteriors

torch.manual_seed(42)

def log_posterior(x, batch):
    log_prob = -torch.sum(x**4, axis=-1) / 10.0 + torch.sum(x**2, axis=-1)
    return log_prob, torch.tensor([])

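Note that the log_posterior has to conform to the signature log_posterior(params, batch) -> log_prob, aux, where params holds the parameters (here a single tensor of shape (2,)) and batch is a batch of data. Because our log_posterior sums over the last axis, it also works on a stack of parameter vectors of shape (..., 2), which we'll use when plotting. More info is on the constructing log posteriors page.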

In this simple example we don't have varying batches, so we just ignore that input. We also don't have any auxiliary information we'd like to keep hold of, so we return an empty tensor.
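As a quick sanity check of the target, the per-coordinate derivative of -x⁴/10 + x² is -0.4x³ + 2x, which vanishes at x = 0 and x = ±√5. The density therefore has four modes at (±√5, ±√5), each with log density 2 × (-25/10 + 5) = 5, while the origin, where both methods below are initialized, is a local minimum with log density 0. The sketch below checks this numerically, calling log_posterior both on a single parameter vector and on a batch of them:

```python
import math
import torch

def log_posterior(x, batch):
    # Same double-well target as above; batch is unused
    log_prob = -torch.sum(x**4, dim=-1) / 10.0 + torch.sum(x**2, dim=-1)
    return log_prob, torch.tensor([])

# Each coordinate has maxima at +/- sqrt(5), so the 2D modes are the
# four sign combinations of (sqrt(5), sqrt(5))
root5 = math.sqrt(5.0)
modes = torch.tensor([[s1 * root5, s2 * root5] for s1 in (-1, 1) for s2 in (-1, 1)])

mode_vals, _ = log_posterior(modes, None)            # batched call, shape (4,)
origin_val, _ = log_posterior(torch.zeros(2), None)  # single call, scalar

print(mode_vals)   # all approximately 5.0
print(origin_val)  # 0.0
```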

Variational Inference

Now we'll fit a diagonal Gaussian variational distribution. This is easy with posteriors:

vi_transform = posteriors.vi.diag.build(
    log_posterior, optimizer=torchopt.adam(lr=1e-2), init_log_sds=-2.0
)
n_vi_steps = 2000
vi_state = vi_transform.init(torch.zeros(2))

nelbos = []
for _ in range(n_vi_steps):
    vi_state = vi_transform.update(vi_state, None)
    nelbos.append(vi_state.nelbo.item())
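To make the tracked objective concrete, here is a minimal sketch of a Monte Carlo NELBO estimate for a diagonal Gaussian q, using the reparameterization trick. This is our own illustration in plain PyTorch, not necessarily how posteriors computes it internally; nelbo_estimate and its n_samples argument are hypothetical names:

```python
import torch

torch.manual_seed(0)

def log_posterior(x, batch):
    # Double-well target from above; batch is unused
    log_prob = -torch.sum(x**4, dim=-1) / 10.0 + torch.sum(x**2, dim=-1)
    return log_prob, torch.tensor([])

def nelbo_estimate(mean, log_sd, n_samples=1000):
    """Monte Carlo estimate of E_q[log q(z) - log p(z)] for q = N(mean, diag(exp(log_sd)^2))."""
    sd = torch.exp(log_sd)
    # Reparameterization: z = mean + sd * eps with eps ~ N(0, I),
    # so gradients can flow back to mean and log_sd
    eps = torch.randn(n_samples, mean.shape[-1])
    z = mean + sd * eps
    log_p, _ = log_posterior(z, None)
    log_q = torch.distributions.Normal(mean, sd).log_prob(z).sum(dim=-1)
    return (log_q - log_p).mean()

log_sd = torch.full((2,), -2.0)  # same initial log sds as init_log_sds=-2.0 above
at_origin = nelbo_estimate(torch.zeros(2), log_sd)
at_mode = nelbo_estimate(torch.full((2,), 5.0 ** 0.5), log_sd)
print(at_origin.item(), at_mode.item())  # NELBO is lower when q sits on a mode
```

Minimizing this quantity with respect to the mean and log standard deviations pulls q onto high-density regions, but a single Gaussian can only sit on one of them at a time.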

Here we've tracked the values of the negative ELBO (NELBO); let's have a look at them:

plt.plot(nelbos)
plt.ylabel("NELBO")

[Figure: NELBO during VI]

Looks like it converged, but there's a fair amount of variance around the minimum; maybe a Gaussian isn't a great fit for our target distribution.
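When a trace is this noisy, smoothing it can make the trend easier to judge by eye. A minimal sketch using a simple moving average; moving_average and its default window are our own choices, not part of posteriors:

```python
import torch

def moving_average(values, window=100):
    """Mean of each run of `window` consecutive values ('valid' mode, so the
    output has len(values) - window + 1 entries)."""
    x = torch.as_tensor(values, dtype=torch.float64)
    # unfold(0, window, 1) gives a (len - window + 1, window) view of sliding windows
    return x.unfold(0, window, 1).mean(dim=-1)

# Means of consecutive pairs: 1.5, 2.5, 3.5, 4.5
print(moving_average([1.0, 2.0, 3.0, 4.0, 5.0], window=2))
```

With the NELBO values from above, plt.plot(moving_average(nelbos)) would plot the smoothed trace.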

SGHMC

Let's generate samples with posteriors.sgmcmc.sghmc instead:

sghmc_transform = posteriors.sgmcmc.sghmc.build(log_posterior, lr=5e-2, alpha=1.0)
n_sghmc_steps = 10000
sghmc_state = sghmc_transform.init(torch.zeros(2))

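# Collect samples in a list and stack once at the end, rather than
# concatenating a growing tensor on every step
samples = [sghmc_state.params.clone()]  # the initial point (zeros)
log_posts = []
for _ in range(n_sghmc_steps):
    sghmc_state = sghmc_transform.update(sghmc_state, None)
    samples.append(sghmc_state.params.clone())
    log_posts.append(sghmc_state.log_posterior.item())
samples = torch.stack(samples)  # shape (n_sghmc_steps + 1, 2)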
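For intuition about what each update does, here is a hand-rolled sketch of an SGHMC-style step: an Euler discretization of underdamped Langevin dynamics with friction alpha, where injected noise of variance 2 · alpha · lr · temperature balances the friction. It assumes full gradients (our batch is ignored, so there is no minibatch noise) and is our own illustration, not necessarily the exact update posteriors implements; sghmc_step and sketch_samples are hypothetical names:

```python
import torch

torch.manual_seed(0)

def log_posterior(x, batch):
    # Double-well target from above; batch is unused
    log_prob = -torch.sum(x**4, dim=-1) / 10.0 + torch.sum(x**2, dim=-1)
    return log_prob, torch.tensor([])

# Full gradient of the log posterior with respect to params
grad_log_post = torch.func.grad(lambda x: log_posterior(x, None)[0])

def sghmc_step(params, momenta, lr=5e-2, alpha=1.0, temperature=1.0):
    """One Euler step of underdamped Langevin dynamics with friction `alpha`."""
    grad = grad_log_post(params)
    new_params = params + lr * momenta
    # Noise variance 2 * alpha * lr * temperature keeps the target stationary
    noise = torch.randn_like(momenta) * (2.0 * alpha * lr * temperature) ** 0.5
    new_momenta = momenta + lr * grad - lr * alpha * momenta + noise
    return new_params, new_momenta

params, momenta = torch.zeros(2), torch.zeros(2)
trace = []
for _ in range(10000):
    params, momenta = sghmc_step(params, momenta)
    trace.append(params)
sketch_samples = torch.stack(trace)  # shape (10000, 2)
```

The friction term alone would make the particle settle into a single mode; the injected noise is what lets it climb over the ridges between modes.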

Here we've tracked the values of the log posterior; let's have a look at them:

plt.plot(log_posts)
plt.ylabel("SGHMC Log Posterior")

[Figure: SGHMC log posterior during sampling]

There's certainly some exploration going on! The log posterior evaluations vary quite a lot during sampling, but for the most part they stay well above the value at initialization (the origin, where the log posterior is 0). Hopefully this means we're exploring the modes of the distribution, but we'll find out for sure when we visualize the samples.

Visualizing

Time to visualize the learnt variational distribution and the samples generated by SGHMC:

lim = 4
x = torch.linspace(-lim, lim, 1000)
X, Y = torch.meshgrid(x, x, indexing="ij")  # explicit indexing avoids a deprecation warning
Z = torch.vmap(log_posterior, in_dims=(0, None))(torch.stack([X, Y], axis=-1), None)[0]
plt.contourf(X, Y, Z, levels=50, cmap="Purples", alpha=0.3, zorder=-1)

mean = vi_state.params
sd_diag = torch.exp(vi_state.log_sd_diag)
Z_gauss = torch.vmap(
    lambda z: -torch.sum(torch.square((z - mean) / sd_diag), axis=-1) / 2.0,
)(torch.stack([X, Y], axis=-1))

plt.contour(X, Y, Z_gauss, levels=5, colors="black", alpha=0.5)
sghmc_samps = plt.scatter(
    samples[:, 0], samples[:, 1], c="r", s=0.5, alpha=0.5, label="SGHMC Samples"
)

vi_legend_line = mlines.Line2D(
    [], [], color="black", label="VI", alpha=0.5, linestyle="--"
)
plt.legend(handles=[vi_legend_line, sghmc_samps])
plt.xlim(-lim, lim)
plt.ylim(-lim, lim)

[Figure: Double Well, showing the target density, the VI Gaussian contours and the SGHMC samples]

We can see the variational Gaussian ignores the multiple modes, but SGHMC explores them well.
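To put a number on how well SGHMC covers the modes, one option is to count how many samples land in each quadrant, since each of the four modes at (±√5, ±√5) sits in its own quadrant. A minimal sketch; mode_occupancy is our own helper, not part of posteriors, and the toy points below are just an illustrative input:

```python
import torch

def mode_occupancy(samples):
    """Fraction of samples in each open quadrant, i.e. near each of the four
    double-well modes. Points lying exactly on an axis are not counted."""
    signs = torch.sign(samples)  # shape (n, 2), entries in {-1, 0, 1}
    fractions = {}
    for sx in (-1.0, 1.0):
        for sy in (-1.0, 1.0):
            in_quadrant = (signs[:, 0] == sx) & (signs[:, 1] == sy)
            fractions[(sx, sy)] = in_quadrant.float().mean().item()
    return fractions

# Toy input with one point per quadrant: every quadrant gets 0.25
toy_points = torch.tensor([[2.0, 2.0], [-2.0, 2.0], [-2.0, -2.0], [2.0, -2.0]])
print(mode_occupancy(toy_points))
```

Applied to the SGHMC samples from above, mode_occupancy(samples) summarizes the scatter plot in four numbers. By symmetry of the target, a well-mixed chain would approach a quarter per quadrant, though a finite run of correlated samples will typically show some imbalance.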

Note

The raw code for this example can be found in the repo at examples/double_well.py.