# Gotchas

## torch.no_grad
If you find yourself running out of memory when using `torch.func.grad` and friends,
it might be because torch is trying to accumulate gradients through your
`torch.func.grad` calls. To prevent this, somewhat counterintuitively,
wrap your code in `torch.no_grad`:
```python
with torch.no_grad():
    grad_f_x = torch.func.grad(f)(params, batch)
```
Don't worry, `torch.no_grad` won't prevent the gradients from being calculated
correctly in the functional call. However, `torch.inference_mode` will turn autograd
off altogether. More info in the torch.func docs.
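As a quick sanity check, here is a minimal sketch (with a made-up quadratic `f`)
showing that gradients still flow inside `torch.no_grad`:

```python
import torch

def f(x):
    return (x**2).sum()

with torch.no_grad():
    g = torch.func.grad(f)(torch.tensor([1.0, 2.0]))

print(g)  # tensor([2., 4.]) - gradients are still computed inside no_grad
```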
## validate_args=False in torch.distributions
posteriors uses `torch.vmap` internally to vectorize over functions, for cool things
like per-sample gradients.

The control flow triggered by `validate_args=True` in `torch.distributions` does not
compose with the control flow in `torch.vmap`, so it is recommended to set
`validate_args=False` when using `torch.distributions` within posteriors:
```python
import torch
from torch.distributions import Normal

torch.vmap(lambda x: Normal(0., 1.).log_prob(x))(torch.arange(3))
# RuntimeError: vmap: It looks like you're attempting to use a
# Tensor in some data-dependent control flow. We don't support that yet,
# please shout over at https://github.com/pytorch/functorch/issues/257 .

torch.vmap(lambda x: Normal(0., 1., validate_args=False).log_prob(x))(torch.arange(3))
# tensor([-0.9189, -1.4189, -2.9189])
```
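To illustrate the per-sample gradients mentioned above, here is a minimal sketch
(the `log_lik` function is a made-up example, not part of the posteriors API)
composing `torch.func.grad` with `torch.vmap`:

```python
import torch
from torch.distributions import Normal

# Hypothetical per-sample log-likelihood, in the style posteriors vectorizes over
def log_lik(mu, x):
    return Normal(mu, 1.0, validate_args=False).log_prob(x)

# Gradient w.r.t. mu, vectorized over the batch dimension of x
per_sample_grads = torch.vmap(torch.func.grad(log_lik), in_dims=(None, 0))(
    torch.tensor(0.0), torch.arange(3.0)
)
# tensor([0., 1., 2.])
```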
## Auxiliary information
posteriors enforces `log_posterior` and `log_likelihood` functions to have a
`log_posterior(params, batch) -> log_prob, aux` signature, where the second element
contains any auxiliary information. If you don't have any auxiliary information, just
return an empty tensor:
```python
def log_posterior(params, batch):
    log_prob = ...
    return log_prob, torch.tensor([])
```
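For a concrete picture, here is a minimal sketch of a valid `log_posterior`
(the Gaussian model is made up for illustration):

```python
import torch
from torch.distributions import Normal

def log_posterior(params, batch):
    # Gaussian likelihood over the batch plus a standard normal prior,
    # with validate_args=False so it composes with torch.vmap
    log_lik = Normal(params, 1.0, validate_args=False).log_prob(batch).sum()
    log_prior = Normal(0.0, 1.0, validate_args=False).log_prob(params).sum()
    return log_lik + log_prior, torch.tensor([])
```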
More info in the constructing log posteriors page.
## inplace
All posteriors algorithms have an update function with signature
`update(state, batch, inplace=False) -> state, aux`[^1]. The `inplace`
argument can be set to `True` to update the state in-place and save memory. However,
posteriors is functional first, so has `inplace=False` as the default.
```python
state2, aux = transform.update(state, batch)
# state is not updated

state2, aux = transform.update(state, batch, inplace=True)
# state is updated and state2 is a pointer to state
```
When adding a new algorithm, in-place support can be achieved by modifying TensorTrees
via the `flexi_tree_map` function:
```python
from posteriors.tree_utils import flexi_tree_map

new_state = flexi_tree_map(lambda x: x + 1, state, inplace=True)
```
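As a sketch of how an algorithm's update step might use this (the `decay_update`
function is hypothetical, not part of posteriors):

```python
from posteriors.tree_utils import flexi_tree_map

# Hypothetical update step: scale every tensor leaf in the state.
# With inplace=False a new state is returned; with inplace=True the
# input state's tensors are overwritten and returned.
def decay_update(state, factor=0.99, inplace=False):
    return flexi_tree_map(lambda x: factor * x, state, inplace=inplace)
```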
As posteriors transform states are immutable NamedTuples, in-place modification of
TensorTree leaves can be achieved by modifying the data of the tensor directly with
`tree_insert_`:

```python
from posteriors.tree_utils import tree_insert_

tree_insert_(state.log_posterior, log_post.detach())
```
However, the `aux` component of the `TransformState` is not guaranteed to be a
TensorTree, and so in-place modification of `aux` is not supported. Using
`state._replace(aux=aux)` will return a state with all TensorTrees pointing to the
same memory as the input `state`, but with a new `aux` component (`aux` is not
modified in the input state object).
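For example, assuming `state` and a freshly computed `aux` are in scope:

```python
# Attach new auxiliary info without copying any tensor memory
new_state = state._replace(aux=aux)
# new_state's TensorTree leaves share storage with state; only aux differs
```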
## torch.tensor with autograd
As specified in the documentation, `torch.tensor` does not preserve autograd history
(it always copies data). If you want to construct a tensor within a differentiable
function, use `torch.stack` instead:
```python
def f_with_tensor(x):
    return torch.tensor([x**2, x**3]).sum()

torch.func.grad(f_with_tensor)(torch.tensor(2.))
# tensor(0.)

def f_with_stack(x):
    return torch.stack([x**2, x**3]).sum()

torch.func.grad(f_with_stack)(torch.tensor(2.))
# tensor(16.)
```
[^1]: Assuming all other args and kwargs have been pre-configured by the `build` function.