Getting Started
Installation
Install from PyPI with pip:

```
pip install posteriors
```
Why UQ?
Uncertainty quantification allows for informed decision making by averaging over multiple plausible model configurations rather than relying on a single point estimate. This provides a coherent framework for detecting out-of-distribution inputs and for continual learning.
For more info on the utility of UQ, check out our blog post introducing posteriors!
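The idea of "averaging over multiple plausible model configurations" can be made concrete with a small example. Below is a dependency-free sketch of Bayesian model averaging for a one-dimensional logistic model. The parameter draws are a hypothetical stand-in for samples from an approximate posterior; here they are just Gaussian perturbations of a point estimate. All names are illustrative and are not part of posteriors.

```python
import math
import random

def predict_prob(weight: float, bias: float, x: float) -> float:
    """Logistic model: probability that the label is 1 for input x."""
    return 1.0 / (1.0 + math.exp(-(weight * x + bias)))

# Hypothetical point estimate, e.g. from standard (non-Bayesian) training.
point_estimate = (2.0, 0.0)

# Hypothetical "plausible" parameter settings. In practice these would come
# from a posterior approximation; here they are Gaussian perturbations.
random.seed(0)
posterior_draws = [
    (random.gauss(point_estimate[0], 1.0), random.gauss(point_estimate[1], 1.0))
    for _ in range(1000)
]

def bayesian_model_average(x: float) -> float:
    """Average the predicted probability over all parameter draws."""
    probs = [predict_prob(w, b, x) for w, b in posterior_draws]
    return sum(probs) / len(probs)

x_new = 5.0
p_point = predict_prob(*point_estimate, x_new)  # sigmoid(10): very confident
p_average = bayesian_model_average(x_new)       # tempered by disagreeing draws
```

Some plausible settings disagree about x_new, so the averaged probability is less extreme than the single point estimate. That disagreement is the extra signal a point estimate cannot provide.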
Quick Start
posteriors is a Python library for uncertainty quantification and machine learning that is designed to be easy to use, flexible and extensible. It is built on top of PyTorch and provides a range of tools for probabilistic modelling, Bayesian inference, and online learning.
Enough small talk, let's train a simple Bayesian neural network using posteriors:
```python
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from torch import nn, utils, func
import torchopt
import posteriors

# MNIST training data and a loader yielding (images, labels) batches
dataset = MNIST(root="./data", transform=ToTensor(), download=True)
train_loader = utils.data.DataLoader(dataset, batch_size=32, shuffle=True)
num_data = len(dataset)

# A small MLP; its parameters form a dict, i.e. a PyTree of tensors
classifier = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 10))
params = dict(classifier.named_parameters())


def log_posterior(params, batch):
    # Mean log likelihood over the batch plus log prior scaled by 1 / num_data;
    # returns (value, aux) where aux is the model output
    images, labels = batch
    images = images.view(images.size(0), -1)
    output = func.functional_call(classifier, params, images)
    log_post_val = (
        -nn.functional.cross_entropy(output, labels)
        + posteriors.diag_normal_log_prob(params) / num_data
    )
    return log_post_val, output


transform = posteriors.vi.diag.build(
    log_posterior, torchopt.adam(), temperature=1 / num_data
)  # Can swap out for any posteriors algorithm

state = transform.init(params)

for batch in train_loader:
    state = transform.update(state, batch)
```
build is a function that loads config_args (here log_posterior, the torchopt.adam() optimizer and the temperature) into the init and update functions and stores them within the transform instance. The init and update functions then conform to a preset signature, allowing for easy switching between algorithms.

state is a NamedTuple encoding the state of the algorithm, including params and aux attributes.

init constructs the iteration-varying state based on the model parameters params.

update updates the state based on a new batch of data.
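The build/init/update pattern described above can be sketched without any dependencies. Everything below (State, Transform, build and the gradient-ascent update) is a hypothetical toy that mirrors the described interface. It is not posteriors' source code. To avoid needing autodiff, the toy algorithm is configured with a gradient function instead of a log posterior.

```python
from functools import partial
from typing import Any, Callable, NamedTuple, Tuple

GradFn = Callable[[float, Any], Tuple[float, Any]]

class State(NamedTuple):
    """Iteration-varying state: current parameters plus auxiliary output."""
    params: float
    aux: Any = None

class Transform(NamedTuple):
    """Holds an algorithm's init and update functions."""
    init: Callable[[float], State]
    update: Callable[[State, Any], State]

def init(params: float) -> State:
    """Construct the initial state from the model parameters."""
    return State(params=params)

def update(state: State, batch: Any, grad_fn: GradFn, lr: float) -> State:
    """One gradient-ascent step on the log posterior (the toy algorithm)."""
    grad, aux = grad_fn(state.params, batch)
    return State(params=state.params + lr * grad, aux=aux)

def build(grad_fn: GradFn, lr: float) -> Transform:
    """Bind the config args so that update only needs (state, batch)."""
    return Transform(init=init, update=partial(update, grad_fn=grad_fn, lr=lr))

# Toy model: prior theta ~ N(0, 1), observations x_i ~ N(theta, 1).
def log_posterior_grad(theta: float, batch: list) -> Tuple[float, int]:
    """Gradient of log N(theta; 0, 1) + sum_i log N(x_i; theta, 1), plus aux."""
    grad = -theta + sum(x - theta for x in batch)
    return grad, len(batch)  # aux: batch size, just to show it is carried along

transform = build(log_posterior_grad, lr=0.1)
state = transform.init(0.0)
for batch in [[1.0, 2.0, 3.0]] * 100:
    state = transform.update(state, batch)
# state.params approaches the posterior mode 6 / 4 = 1.5
```

Binding the configuration once in build is what keeps the training loop identical across algorithms: the loop only ever calls init(params) and update(state, batch).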
Here we've used posteriors.vi.diag, but we could easily swap to any of the other posteriors algorithms, such as posteriors.laplace.diag_fisher or posteriors.sgmcmc.sghmc.
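Every algorithm exposes the same init/update signature, so the training loop does not change when the algorithm does. The following dependency-free toy sketch illustrates this. build_map and build_langevin are hypothetical, and the unadjusted Langevin sampler is a simpler SG-MCMC method standing in for SGHMC, not posteriors' implementation.

```python
import math
import random
from typing import Any, Callable, List, NamedTuple, Tuple

class State(NamedTuple):
    params: float
    aux: Any = None

class Transform(NamedTuple):
    init: Callable[[float], State]
    update: Callable[[State, Any], State]

def build_map(grad_fn: Callable, lr: float) -> Transform:
    """Gradient ascent: converges to the posterior mode."""
    def update(state: State, batch: Any) -> State:
        grad, aux = grad_fn(state.params, batch)
        return State(state.params + lr * grad, aux)
    return Transform(init=State, update=update)

def build_langevin(grad_fn: Callable, lr: float) -> Transform:
    """Unadjusted Langevin dynamics: a gradient step plus Gaussian noise,
    so the iterates approximately sample from the posterior."""
    def update(state: State, batch: Any) -> State:
        grad, aux = grad_fn(state.params, batch)
        noise = math.sqrt(2.0 * lr) * random.gauss(0.0, 1.0)
        return State(state.params + lr * grad + noise, aux)
    return Transform(init=State, update=update)

def run(transform: Transform, params: float, batches: List[Any]) -> Tuple[State, List[float]]:
    """The same loop for any algorithm with the shared init/update signature."""
    state = transform.init(params)
    trace = []
    for batch in batches:
        state = transform.update(state, batch)
        trace.append(state.params)
    return state, trace

def log_posterior_grad(theta: float, batch: list) -> Tuple[float, None]:
    """Gradient of log N(theta; 0, 1) + sum_i log N(x_i; theta, 1)."""
    return -theta + sum(x - theta for x in batch), None

batches = [[1.0, 2.0, 3.0]] * 5000
map_state, _ = run(build_map(log_posterior_grad, lr=0.1), 0.0, batches)

random.seed(0)
_, trace = run(build_langevin(log_posterior_grad, lr=0.01), 0.0, batches)
samples = trace[1000:]  # drop burn-in
sample_mean = sum(samples) / len(samples)
```

Swapping algorithms only changes the build_* call. The MAP run lands on the mode 1.5, while the Langevin samples scatter around it with variance close to the exact posterior variance 1/4, up to discretisation bias and Monte Carlo error.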
I want more!
- The Visualizing VI and SGHMC tutorial provides a walkthrough of a simple example demonstrating how to use posteriors and easily switch between algorithms.
- posteriors expects log_posterior to take a certain form; learn more on the constructing log posteriors page.
- Our API documentation provides detailed descriptions of all the posteriors algorithms and utilities.
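The following derivation is our reading of the Quick Start code, not a statement of the library's full requirements for log_posterior. Let N = num_data and B be a minibatch. The mean-reduced cross_entropy term is the negative average log likelihood over B, so log_posterior returns a minibatch estimate of the log posterior divided by N, up to an additive constant. Dividing the ELBO by the same N then places the temperature T = 1/N in front of the entropy term. This last step assumes that the VI temperature weights that term.

```latex
\begin{aligned}
\log p(\theta \mid \mathcal{D})
  &= \sum_{i=1}^{N} \log p(y_i \mid x_i, \theta) + \log p(\theta) - \log p(\mathcal{D}), \\
\frac{1}{N} \log p(\theta \mid \mathcal{D})
  &\approx \underbrace{\frac{1}{|B|} \sum_{i \in B} \log p(y_i \mid x_i, \theta)}_{-\texttt{cross\_entropy}}
   + \frac{1}{N} \log p(\theta) + \text{const}, \\
\frac{1}{N}\,\mathrm{ELBO}(q)
  &= \mathbb{E}_{q}\!\left[\frac{1}{N} \log p(\theta, \mathcal{D})\right]
   - T\, \mathbb{E}_{q}\!\left[\log q(\theta)\right],
   \qquad T = \frac{1}{N}.
\end{aligned}
```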
PyTrees
The internals of posteriors rely on optree to apply functions across arbitrary PyTrees of tensors (i.e. TensorTrees). For example:
```python
params_squared = optree.tree_map(lambda x: x**2, params)
```
squares every tensor in params, where params can be a dict, list, tuple, or any other PyTree.
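To show what "applying a function across a PyTree" means, here is a minimal recursive tree_map written from scratch. It is an illustration only, not optree's implementation. It handles dict, list, tuple and namedtuple containers and treats everything else (tensors, floats) as a leaf. optree supports many more node types and lets you register custom ones.

```python
from typing import Any, Callable

def tree_map(fn: Callable[[Any], Any], tree: Any) -> Any:
    """Apply fn to every leaf of tree, preserving the container structure."""
    if isinstance(tree, dict):
        return {key: tree_map(fn, value) for key, value in tree.items()}
    if isinstance(tree, tuple) and hasattr(tree, "_fields"):  # namedtuple
        return type(tree)(*(tree_map(fn, value) for value in tree))
    if isinstance(tree, (list, tuple)):
        return type(tree)(tree_map(fn, value) for value in tree)
    return fn(tree)  # anything else is a leaf

params = {"weight": [1.0, 2.0], "bias": (3.0,)}
params_squared = tree_map(lambda x: x**2, params)
# params_squared == {"weight": [1.0, 4.0], "bias": (9.0,)}
```

The result has exactly the same nesting as the input, and the input tree itself is left untouched.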
posteriors also provides a posteriors.flexi_tree_map function that additionally supports in-place operation:
```python
params_squared = posteriors.flexi_tree_map(lambda x: x**2, params, inplace=True)
```
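In-place mapping overwrites the existing tensors instead of allocating a new tree. This saves memory for large models, and every other reference to those tensors sees the update. Below is a minimal torch-only sketch of that behaviour. square_tree_inplace_ is a hypothetical helper for a flat dict, not posteriors.flexi_tree_map.

```python
import torch

def square_tree_inplace_(tree: dict) -> dict:
    """Hypothetical helper: square every tensor in a flat dict, in place."""
    with torch.no_grad():  # permits in-place writes to tensors that require grad
        for tensor in tree.values():
            tensor.pow_(2)  # overwrites the tensor's existing storage
    return tree

params = {"weight": torch.tensor([1.0, 2.0]), "bias": torch.tensor([3.0])}
weight_ref = params["weight"]  # a second reference to the same tensor object
square_tree_inplace_(params)
# weight_ref now holds tensor([1., 4.]) too, since no new tensor was created
```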
torch.func
Instead of using torch's more common loss.backward() style of automatic differentiation, posteriors uses a functional approach via torch.func.grad and friends. The functional approach is easier to test, composes better with other tools and, importantly for posteriors, makes for code that is closer to the mathematical notation.
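The two styles can be compared side by side using only PyTorch, with f as an arbitrary example function. With backward(), the gradient appears as a side effect in the .grad attribute of a tensor marked requires_grad. With torch.func.grad, grad(f) is itself a function that returns the gradient, mirroring the notation ∇f(x).

```python
import torch
from torch.func import grad

def f(x: torch.Tensor) -> torch.Tensor:
    """Example scalar function f(x) = sum(x**3); its gradient is 3 * x**2."""
    return (x**3).sum()

x = torch.tensor([1.0, 2.0, 3.0])

# Stateful style: mark x as requiring grad, call backward(), read x.grad.
x_leaf = x.clone().requires_grad_(True)
f(x_leaf).backward()
grad_backward = x_leaf.grad

# Functional style: grad(f) is a new function that returns df/dx directly.
grad_functional = grad(f)(x)
```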
For example, the gradient of a function f with respect to x can be computed as:
```python
grad_f_x = torch.func.grad(f)(x)
```
Here f is a function that takes x as input and returns a scalar output. Again, x can be a dict, list, tuple, or any other PyTree with torch.Tensor leaves.
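For instance, when x is a dict of tensors, torch.func.grad returns a dict of gradients with the same keys. Passing has_aux=True lets f return a (scalar, auxiliary) pair, and only the first element is differentiated. This matches the (log_post_val, output) pair returned by log_posterior in the Quick Start. The function f here is a made-up example.

```python
import torch
from torch.func import grad

def f(x: dict) -> tuple:
    """Scalar output to differentiate, plus an auxiliary output that is not."""
    value = (x["a"] ** 2).sum() + 3.0 * x["b"].sum()
    aux = x["a"].mean()  # e.g. a quantity to log; returned alongside the gradient
    return value, aux

x = {"a": torch.tensor([1.0, -2.0]), "b": torch.tensor([0.5, 0.5, 0.5])}
grads, aux = grad(f, has_aux=True)(x)
# grads["a"] == 2 * x["a"]; grads["b"] == 3 everywhere; aux == mean of x["a"]
```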
Friends
Compose posteriors with wonderful tools from the torch ecosystem:
- Define priors and likelihoods with torch.distributions. Remember to set validate_args=False and construct the log posterior accordingly.
- torchopt for functional optimizers.
- transformers for open source models.
- lightning for logging and device management. Check out the lightning integration tutorial.
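Here is a minimal sketch of the first point: a log posterior for the unknown mean of some toy observations, with prior and likelihood from torch.distributions and validate_args=False. The data and names are hypothetical. The sketch follows the same scaling as the Quick Start (mean log likelihood plus log prior divided by num_data) and returns a (value, aux) pair. Our understanding is that argument validation adds data-dependent checks that do not mix well with torch.func transforms such as vmap, which is why it is switched off.

```python
import torch
from torch.distributions import Normal
from torch.func import grad

data = torch.tensor([1.0, 2.0, 3.0])  # toy observations
num_data = len(data)

def log_posterior(params: dict, batch: torch.Tensor):
    """Mean log likelihood over the batch + log prior / num_data, plus aux."""
    prior = Normal(0.0, 1.0, validate_args=False)
    likelihood = Normal(params["mu"], 1.0, validate_args=False)
    log_lik = likelihood.log_prob(batch).mean()
    log_prior = prior.log_prob(params["mu"]).sum() / num_data
    return log_lik + log_prior, log_lik

params = {"mu": torch.tensor(0.0)}
grads, log_lik = grad(log_posterior, has_aux=True)(params, data)
# Analytically: mean(data - mu) - mu / num_data = 2.0 at mu = 0
```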
Additionally, the functional transform interface used in posteriors is strongly inspired by frameworks such as optax and blackjax.
posteriors also draws inspiration from other UQ libraries, including fortuna, laplace, numpyro, pymc and uncertainty-baselines.