Getting Started
Installation
Install from PyPI with pip:
pip install posteriors
Why UQ?
Uncertainty quantification (UQ) allows for informed decision making by averaging over multiple plausible model configurations rather than relying on a single point estimate. This provides a coherent framework for detecting out-of-distribution inputs and for continual learning.
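To make the averaging idea concrete, here is a minimal PyTorch sketch that is not part of the posteriors API. The logits are assumed to come from S parameter samples of one classifier (the numbers are made up), and bma_predict is a hypothetical helper that averages class probabilities over the samples and uses predictive entropy as a simple out-of-distribution score.

```python
import torch


def bma_predict(logits_per_sample: torch.Tensor):
    """Hypothetical helper: Bayesian model averaging over parameter samples.

    logits_per_sample has shape (S, N, C): S parameter samples,
    N inputs and C classes.
    """
    probs = torch.softmax(logits_per_sample, dim=-1)  # (S, N, C)
    mean_probs = probs.mean(dim=0)  # average over samples -> (N, C)
    # Predictive entropy is high when the samples disagree, flagging inputs
    # the model is unsure about (e.g. out-of-distribution ones).
    entropy = -(mean_probs * mean_probs.clamp_min(1e-12).log()).sum(dim=-1)
    return mean_probs, entropy


# Two made-up parameter samples, two inputs, two classes.
logits = torch.tensor(
    [
        [[5.0, 0.0], [5.0, 0.0]],  # sample A: predicts class 0 for both inputs
        [[5.0, 0.0], [0.0, 5.0]],  # sample B: disagrees with A on input 1
    ]
)
mean_probs, entropy = bma_predict(logits)
print(entropy)  # input 1 (where the samples disagree) scores higher than input 0
```

A single point estimate (sample A alone) would be confident on both inputs; only the average over samples reveals that input 1 is ambiguous.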
For more info on the utility of UQ, check out our blog post introducing posteriors!
Quick Start
posteriors is a Python library for uncertainty quantification and machine learning
that is designed to be easy to use, flexible and extensible. It is built on top
of PyTorch and provides a range of
tools for probabilistic modelling, Bayesian inference, and online learning.
Enough small talk, let's train a simple Bayesian neural network using posteriors:
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from torch import nn, utils, func
import torchopt
import posteriors

dataset = MNIST(root="./data", transform=ToTensor(), download=True)
train_loader = utils.data.DataLoader(dataset, batch_size=32, shuffle=True)
num_data = len(dataset)

classifier = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 10))
params = dict(classifier.named_parameters())


def log_posterior(params, batch):
    images, labels = batch
    images = images.view(images.size(0), -1)  # flatten each 28x28 image to a vector
    output = func.functional_call(classifier, params, images)
    # Batch-averaged log likelihood plus a diagonal normal prior scaled by 1 / num_data
    log_post_val = (
        -nn.functional.cross_entropy(output, labels)
        + posteriors.diag_normal_log_prob(params) / num_data
    )
    return log_post_val, output


transform = posteriors.vi.diag.build(
    log_posterior, torchopt.adam(), temperature=1 / num_data
)  # Can swap out for any posteriors algorithm

state = transform.init(params)

for batch in train_loader:
    state, aux = transform.update(state, batch)
- build is a function that loads config_args (here log_posterior, the torchopt.adam() optimizer and temperature) into the init and update functions and stores them within the transform instance. The init and update functions then conform to a preset signature, allowing for easy switching between algorithms.
- state is a NamedTuple encoding the state of the algorithm, including params and aux attributes.
- init constructs the iteration-varying state based on the model parameters params.
- update updates the state based on a new batch of data and also returns any auxiliary information from the model call.
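The build/init/update pattern can be mimicked in a few lines of plain PyTorch. The sketch below is a hypothetical toy transform (plain gradient ascent on the log posterior), not how posteriors implements any of its algorithms; SketchState, SketchTransform, build and lr here are illustrative names only. It mirrors the signatures described above: init(params) returns a NamedTuple state, and update(state, batch) returns the new state together with the auxiliary output of log_posterior.

```python
from typing import Any, Callable, NamedTuple, Tuple

import torch
from torch import func


class SketchState(NamedTuple):
    """Hypothetical state: current parameters and the latest log posterior value."""

    params: dict
    log_posterior: torch.Tensor


class SketchTransform(NamedTuple):
    """Container for the init and update functions produced by build."""

    init: Callable
    update: Callable


def build(log_posterior: Callable, lr: float = 0.1) -> SketchTransform:
    """Hypothetical build: bakes the config args (log_posterior, lr) into init/update."""

    def init(params: dict) -> SketchState:
        return SketchState(params, torch.tensor(float("nan")))

    def update(state: SketchState, batch: Any) -> Tuple[SketchState, Any]:
        # grad_and_value with has_aux=True returns (grads, (value, aux)).
        grads, (value, aux) = func.grad_and_value(log_posterior, has_aux=True)(
            state.params, batch
        )
        # Gradient ascent, since we maximise the log posterior.
        new_params = {k: p + lr * grads[k] for k, p in state.params.items()}
        return SketchState(new_params, value), aux

    return SketchTransform(init, update)


# Toy log posterior: Gaussian centred on the batch value; aux is the residual.
def log_posterior(params, batch):
    residual = batch - params["x"]
    return -0.5 * (residual**2).sum(), residual


transform = build(log_posterior, lr=0.1)
state = transform.init({"x": torch.zeros(1)})
for _ in range(200):
    state, aux = transform.update(state, torch.tensor([3.0]))
print(state.params["x"])  # converges towards 3.0
```

Because the training loop only touches transform.init and transform.update, swapping in a different build function leaves the loop unchanged, which is the point of the fixed signature.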
Here we used posteriors.vi.diag, but we could just as easily swap in any of the other
posteriors algorithms, such as posteriors.laplace.diag_fisher or
posteriors.sgmcmc.sghmc.
I want more!
The Visualizing VI and SGHMC tutorial provides
a walkthrough for a simple example demonstrating how to use posteriors and easily
switch between algorithms.
posteriors expects log_posterior to take a certain form; learn more on the
constructing log posteriors page.
Our API documentation provides detailed descriptions for all
of the posteriors algorithms and utilities.
PyTrees
The internals of posteriors rely on optree to
apply functions across arbitrary PyTrees of tensors (i.e. TensorTrees). For example:
import optree
params_squared = optree.tree_map(lambda x: x**2, params)
This squares every tensor in params, where params can be a
dict, list, tuple, or any other PyTree.
posteriors also provides a posteriors.flexi_tree_map function that adds support for in-place operations:
params_squared = posteriors.flexi_tree_map(lambda x: x**2, params, inplace=True)
Here the tensors of params are modified in place rather than copied.
torch.func
Instead of using torch's more common loss.backward() style automatic differentiation,
posteriors uses a functional approach, via torch.func.grad and friends. The functional
approach is easier to test, composes better with other tools and, importantly for posteriors,
makes for code that is closer to the mathematical notation.
For example, the gradient of a function f with respect to x can be computed as:
grad_f_x = torch.func.grad(f)(x)
f is a function that takes x as input and returns a scalar output. Again,
x can be a dict, list, tuple, or any other PyTree with torch.Tensor leaves.
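For instance, with a made-up scalar function f over a dict of tensors, torch.func.grad returns gradients in the same PyTree structure as x. The sketch also checks the result against the loss.backward() style for comparison:

```python
import torch
from torch import func


def f(x):
    """Scalar function of a dict PyTree: sum(a^2) + sum(b^3)."""
    return (x["a"] ** 2).sum() + (x["b"] ** 3).sum()


x = {"a": torch.tensor([1.0, -2.0]), "b": torch.tensor([0.5, 3.0])}

# Functional style: grad(f) is itself a function returning a PyTree like x.
grad_f_x = func.grad(f)(x)
print(grad_f_x["a"])  # equals 2 * a
print(grad_f_x["b"])  # equals 3 * b**2

# Equivalent loss.backward() style, which stores gradients on the leaves.
x_leaves = {k: v.clone().requires_grad_(True) for k, v in x.items()}
f(x_leaves).backward()
print(all(torch.allclose(grad_f_x[k], x_leaves[k].grad) for k in x))  # True
```

The functional version reads like the maths (the gradient of f evaluated at x) and leaves no mutable .grad state behind.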
Friends
Compose posteriors with wonderful tools from the torch ecosystem:
- Define priors and likelihoods with torch.distributions. Remember to set validate_args=False and construct the log posterior accordingly.
- torchopt for functional optimizers.
- transformers for open source models.
- lightning for logging and device management. Check out the lightning integration tutorial.
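A log posterior built from torch.distributions might look like the following sketch, which mirrors the structure of the quick-start log_posterior: a Categorical likelihood on the classifier logits (batch-averaged) plus an independent standard normal prior scaled by 1 / num_data. The small nn.Linear model, the random batch and num_data = 1000 are assumptions for illustration; validate_args=False is set on both distributions as recommended above.

```python
import torch
from torch import func, nn
from torch.distributions import Categorical, Normal

torch.manual_seed(0)
classifier = nn.Linear(4, 3)  # small stand-in for a real model
params = dict(classifier.named_parameters())
num_data = 1000  # assumed dataset size


def log_posterior(params, batch):
    inputs, labels = batch
    output = func.functional_call(classifier, params, (inputs,))
    # Likelihood: mean log probability of the labels under the predicted logits.
    log_lik = Categorical(logits=output, validate_args=False).log_prob(labels).mean()
    # Prior: independent standard normals over every parameter entry.
    prior = Normal(0.0, 1.0, validate_args=False)
    log_prior = sum(prior.log_prob(p).sum() for p in params.values())
    return log_lik + log_prior / num_data, output


batch = (torch.randn(8, 4), torch.randint(0, 3, (8,)))
value, output = log_posterior(params, batch)
print(value)
```

The Categorical likelihood term equals the negative cross entropy used in the quick start, so the two formulations describe the same log posterior.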
Additionally, the functional transform interface used in posteriors is strongly
inspired by frameworks such as optax and
blackjax, as well as by other UQ libraries: fortuna,
laplace, numpyro,
pymc and uncertainty-baselines.