Gotchas

torch.no_grad

If you find yourself running out of memory when using torch.func.grad and friends, it might be because regular (outer) autograd is recording a graph through your torch.func.grad calls so that it can later differentiate through them. To prevent this, somewhat counterintuitively, wrap your code in torch.no_grad:

with torch.no_grad():
    grad_f_x = torch.func.grad(f)(params, batch)

Don't worry, torch.no_grad won't prevent the gradients from being calculated correctly in the functional call. torch.inference_mode, on the other hand, turns autograd off altogether, so it is not a substitute for torch.no_grad here. More info in the torch.func docs.
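As a quick sanity check, here is a minimal sketch showing that torch.func.grad still returns the correct gradient when called inside torch.no_grad. The function f, the toy linear-model data and the analytic comparison are illustrative assumptions, not part of posteriors.

```python
import torch

# Hypothetical toy objective: mean squared error of a linear model
def f(params, batch):
    x, y = batch
    return ((x @ params - y) ** 2).mean()

torch.manual_seed(0)
# params also requires grad, as if it lived in an outer autograd graph
params = torch.randn(3, requires_grad=True)
batch = (torch.randn(5, 3), torch.randn(5))

with torch.no_grad():
    # torch.func.grad ignores the outer no_grad, so the gradient is still computed
    grad_f_x = torch.func.grad(f)(params, batch)

# Analytic gradient for comparison: 2/n * x^T (x @ params - y)
x, y = batch
expected = 2 / x.shape[0] * x.T @ (x @ params.detach() - y)
print(torch.allclose(grad_f_x, expected, atol=1e-5))  # True
```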

validate_args=False in torch.distributions

posteriors uses torch.vmap internally to vectorize over functions, for cool things like per-sample gradients. With validate_args=True (enabled by default), torch.distributions checks its arguments and samples using data-dependent Python control flow, which torch.vmap does not support. So it is recommended to set validate_args=False when using torch.distributions in posteriors:

import torch
from torch.distributions import Normal

torch.vmap(lambda x: Normal(0., 1.).log_prob(x))(torch.arange(3))
# RuntimeError: vmap: It looks like you're attempting to use a
# Tensor in some data-dependent control flow. We don't support that yet, 
# please shout over at https://github.com/pytorch/functorch/issues/257 .

torch.vmap(lambda x: Normal(0., 1., validate_args=False).log_prob(x))(torch.arange(3))
# tensor([-0.9189, -1.4189, -2.9189])
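The same setting is what makes per-sample gradients work. The sketch below composes torch.vmap with torch.func.grad on a hypothetical one-parameter Gaussian log-likelihood, log_lik, which is an illustrative assumption and not a posteriors function. With validate_args=False the distribution can be built inside the transformed function.

```python
import torch
from torch.distributions import Normal

def log_lik(loc, x):
    # validate_args=False skips the data-dependent checks that vmap does not support
    return Normal(loc, 1.0, validate_args=False).log_prob(x)

loc = torch.tensor(0.5)
xs = torch.tensor([0.0, 1.0, 2.0])

# Gradient w.r.t. loc for each data point: in_dims=(None, 0) shares loc, maps over xs
per_sample_grads = torch.vmap(torch.func.grad(log_lik), in_dims=(None, 0))(loc, xs)
# Analytically d/dloc log N(x | loc, 1) = x - loc, i.e. [-0.5, 0.5, 1.5]
```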

Auxiliary information

posteriors requires log_posterior and log_likelihood functions to have the signature log_posterior(params, batch) -> log_prob, aux, where the second element contains any auxiliary information. If you don't have any auxiliary information, just return an empty tensor:

def log_posterior(params, batch):
    log_prob = ...
    return log_prob, torch.tensor([])

More info in the constructing log posteriors page.
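For a concrete picture, here is a sketch of a toy log posterior that returns its predictions as aux. The model is a hypothetical Bayesian linear regression chosen for illustration. The (log_prob, aux) tuple is also the shape that torch.func.grad expects with has_aux=True, which differentiates only the first element and passes the second through.

```python
import torch
from torch.distributions import Normal

def log_posterior(params, batch):
    # Hypothetical model: linear regression, unit-variance noise, N(0, 1) prior
    x, y = batch
    preds = x @ params
    log_lik = Normal(preds, 1.0, validate_args=False).log_prob(y).sum()
    log_prior = Normal(0.0, 1.0, validate_args=False).log_prob(params).sum()
    # Second element is the auxiliary information: here, the predictions
    return log_lik + log_prior, preds

torch.manual_seed(0)
params = torch.zeros(2)
batch = (torch.randn(4, 2), torch.randn(4))

# has_aux=True: differentiate the first output, return the second unchanged
grads, preds = torch.func.grad(log_posterior, has_aux=True)(params, batch)
print(grads.shape, preds.shape)  # torch.Size([2]) torch.Size([4])
```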

inplace

All posteriors algorithms have an update function with signature update(state, batch, inplace=False) -> state (see note 1). The inplace argument can be set to True to update the state in place and save memory. However, posteriors is functional first, so inplace=False is the default.

state2 = transform.update(state, batch)
# state is not updated

state2 = transform.update(state, batch, inplace=True)
# state is updated and state2 is a pointer to state
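To illustrate the convention when writing a new algorithm, here is a hypothetical gradient-descent update in plain PyTorch. ToyState and toy_update are illustrative stand-ins, not posteriors classes: the functional path allocates new tensors and leaves the input untouched, while the in-place path writes into the existing memory and returns the same state object.

```python
from typing import NamedTuple
import torch

class ToyState(NamedTuple):
    # Hypothetical state holding a single parameter tensor
    params: torch.Tensor

def toy_update(state: ToyState, grad: torch.Tensor, lr: float = 0.1,
               inplace: bool = False) -> ToyState:
    """Hypothetical gradient-descent step following the inplace convention."""
    if inplace:
        # Overwrite the existing tensor memory and hand back the same state
        state.params.sub_(lr * grad)
        return state
    # Functional path: new tensor, new state, input left untouched
    return ToyState(params=state.params - lr * grad)

state = ToyState(params=torch.ones(2))
grad = torch.ones(2)

state2 = toy_update(state, grad)
print(state.params)   # tensor([1., 1.])
print(state2.params)  # tensor([0.9000, 0.9000])

state3 = toy_update(state, grad, inplace=True)
print(state3 is state)  # True
print(state.params)     # tensor([0.9000, 0.9000])
```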

When adding a new algorithm, in-place support can be achieved by modifying TensorTrees via the flexi_tree_map function:

from posteriors.tree_utils import flexi_tree_map

new_state = flexi_tree_map(lambda x: x + 1, state, inplace=True)

Because posteriors transform states are immutable NamedTuples, their fields cannot be reassigned. Instead, TensorTree leaves can be modified in place by writing directly into the tensors' data with tree_insert_:

from posteriors.tree_utils import tree_insert_

tree_insert_(state.log_posterior, log_post.detach())

However, the aux component of the TransformState is not guaranteed to be a TensorTree, so in-place modification of aux is not supported. Instead, state._replace(aux=aux) returns a new state whose TensorTree fields point to the same memory as the input state but which carries the new aux component. The aux of the input state object is left unchanged.
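The memory sharing behind _replace is standard NamedTuple behaviour and can be checked directly. In the sketch below, ToyState is a hypothetical stand-in for a posteriors TransformState with the same three field names.

```python
from typing import Any, NamedTuple
import torch

class ToyState(NamedTuple):
    # Hypothetical stand-in for a posteriors transform state
    params: torch.Tensor
    log_posterior: torch.Tensor
    aux: Any = None

state = ToyState(params=torch.zeros(3), log_posterior=torch.tensor(0.0))

# _replace builds a new tuple; untouched fields are the very same tensor objects
new_state = state._replace(aux={"note": "new aux"})
print(new_state.params is state.params)                        # True
print(new_state.params.data_ptr() == state.params.data_ptr())  # True
print(state.aux)                                               # None

# So an in-place edit through one state is visible through the other
new_state.params.add_(1.0)
print(state.params)  # tensor([1., 1., 1.])
```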

torch.tensor with autograd

As noted in the PyTorch documentation, torch.tensor does not preserve autograd history: it always copies its data into a new tensor that is disconnected from the graph. If you want to construct a tensor within a differentiable function, use torch.stack instead:

def f_with_tensor(x):
    return torch.tensor([x**2, x**3]).sum()

torch.func.grad(f_with_tensor)(torch.tensor(2.))
# tensor(0.)

def f_with_stack(x):
    return torch.stack([x**2, x**3]).sum()

torch.func.grad(f_with_stack)(torch.tensor(2.))
# tensor(16.)
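A common place this bites is assembling a matrix from individual parameters, for example a Cholesky factor. Here is a sketch using a hypothetical helper, scale_tril, that nests torch.stack calls where torch.tensor([[a, 0.], [b, c]]) would silently drop the gradient. Since trace(L Lᵀ) = a² + b² + c², the gradient is 2·params.

```python
import torch

def scale_tril(params):
    # Hypothetical helper: 2x2 lower-triangular matrix from 3 differentiable entries
    a, b, c = params
    zero = torch.zeros_like(a)
    # Nested torch.stack keeps the autograd history of a, b and c
    return torch.stack([torch.stack([a, zero]), torch.stack([b, c])])

def f(params):
    L = scale_tril(params)
    # trace(L @ L.T) equals the sum of squared entries a^2 + b^2 + c^2
    return (L @ L.T).trace()

params = torch.tensor([1.0, 2.0, 3.0])
print(torch.func.grad(f)(params))  # tensor([2., 4., 6.])
```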

  1. Assuming all other args and kwargs have been pre-configured by the build function.