Laplace Dense Fisher

posteriors.laplace.dense_fisher.build(log_posterior, per_sample=False, init_prec=0.0)

Builds a transform for dense empirical Fisher information Laplace approximation.

The empirical Fisher is defined here as: $$ F(θ) = \sum_i ∇_θ \log p(y_i, θ | x_i) ∇_θ \log p(y_i, θ | x_i)^T $$ where \(p(y_i, θ | x_i)\) is the joint model distribution (equivalent to the posterior up to proportionality) with parameters \(θ\), inputs \(x_i\) and labels \(y_i\).

More info on empirical Fisher matrices can be found in [Martens, 2020](https://jmlr.org/papers/volume21/17-678/17-678.pdf) and their use within a Laplace approximation in [Daxberger et al, 2021](https://arxiv.org/abs/2106.14806).
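
To make the definition concrete, the sum of per-sample gradient outer products can be computed directly with `torch.func`. The following is a minimal sketch, not the library implementation; the toy joint log density and data are illustrative assumptions:

```python
import torch
from torch.func import grad, vmap

def log_prob(theta, x, y):
    # Toy joint log density log p(y, θ | x): Gaussian likelihood plus Gaussian prior.
    return -0.5 * (y - x @ theta) ** 2 - 0.5 * (theta**2).sum()

theta = torch.randn(3)
xs, ys = torch.randn(10, 3), torch.randn(10)

# Per-sample gradients g_i = ∇_θ log p(y_i, θ | x_i), stacked into shape (10, 3).
grads = vmap(grad(log_prob), in_dims=(None, 0, 0))(theta, xs, ys)

# Dense empirical Fisher F(θ) = Σ_i g_i g_iᵀ, shape (3, 3).
fisher = grads.T @ grads
```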

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `log_posterior` | `LogProbFn` | Function that takes parameters and input batch and returns the log posterior value (which can be unnormalised) as well as auxiliary information, e.g. from the model call. | *required* |
| `per_sample` | `bool` | If True, then log_posterior is assumed to return a vector of log posteriors for each sample in the batch. If False, then log_posterior is assumed to return a scalar log posterior for the whole batch, in which case torch.func.vmap will be called; this is typically slower than directly writing log_posterior to be per sample. | `False` |
| `init_prec` | `Tensor \| float` | Initial precision matrix. If it is a float, it is defined as an identity matrix scaled by that float. | `0.0` |

Returns:

| Type | Description |
| --- | --- |
| `Transform` | Empirical Fisher information Laplace approximation transform instance. |
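
A minimal usage sketch, built around the standard posteriors `init`/`update` pattern. The toy model, prior, and synthetic batches below are illustrative assumptions, not part of the API:

```python
import torch
from torch.func import functional_call
from posteriors.laplace import dense_fisher

model = torch.nn.Linear(3, 1)  # toy model
params = dict(model.named_parameters())

def log_posterior(p, batch):
    # Per-sample unnormalised joint log density; aux carries the predictions.
    x, y = batch
    preds = functional_call(model, p, (x,))
    log_lik = -0.5 * ((preds - y) ** 2).sum(-1)
    log_prior = -0.5 * sum((v**2).sum() for v in p.values())
    return log_lik + log_prior, preds

transform = dense_fisher.build(log_posterior, per_sample=True)
state = transform.init(params)

batches = [(torch.randn(8, 3), torch.randn(8, 1)) for _ in range(5)]
for batch in batches:  # each update adds the batch's Fisher to state.prec
    state = transform.update(state, batch)
```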

Source code in posteriors/laplace/dense_fisher.py
```python
def build(
    log_posterior: LogProbFn,
    per_sample: bool = False,
    init_prec: torch.Tensor | float = 0.0,
) -> Transform:
    """Builds a transform for dense empirical Fisher information
    Laplace approximation.

    The empirical Fisher is defined here as:
    $$
    F(θ) = \\sum_i ∇_θ \\log p(y_i, θ | x_i) ∇_θ \\log p(y_i, θ | x_i)^T
    $$
    where $p(y_i, θ | x_i)$ is the joint model distribution (equivalent to the posterior
    up to proportionality) with parameters $θ$, inputs $x_i$ and labels $y_i$.

    More info on empirical Fisher matrices can be found in
    [Martens, 2020](https://jmlr.org/papers/volume21/17-678/17-678.pdf) and
    their use within a Laplace approximation in [Daxberger et al, 2021](https://arxiv.org/abs/2106.14806).

    Args:
        log_posterior: Function that takes parameters and input batch and
            returns the log posterior value (which can be unnormalised)
            as well as auxiliary information, e.g. from the model call.
        per_sample: If True, then log_posterior is assumed to return a vector of
            log posteriors for each sample in the batch. If False, then log_posterior
            is assumed to return a scalar log posterior for the whole batch, in which
            case torch.func.vmap will be called; this is typically slower than
            directly writing log_posterior to be per sample.
        init_prec: Initial precision matrix.
            If it is a float, it is defined as an identity matrix
            scaled by that float.

    Returns:
        Empirical Fisher information Laplace approximation transform instance.
    """
    init_fn = partial(init, init_prec=init_prec)
    update_fn = partial(update, log_posterior=log_posterior, per_sample=per_sample)
    return Transform(init_fn, update_fn)
```

posteriors.laplace.dense_fisher.DenseLaplaceState

Bases: NamedTuple

State encoding a Normal distribution over parameters, with a dense precision matrix.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `params` | `TensorTree` | Mean of the Normal distribution. |
| `prec` | `Tensor` | Precision matrix of the Normal distribution. |
| `aux` | `Any` | Auxiliary information from the log_posterior call. |
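
Because the state is a `NamedTuple`, its fields can be read off directly. A hypothetical inspection, continuing the `state` from the build example above (inverting the precision is illustrative, not numerically recommended):

```python
mean = state.params                 # TensorTree: same structure as params
cov = torch.linalg.inv(state.prec)  # dense covariance from the precision matrix
print(state.aux.shape)              # aux from the most recent log_posterior call
```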

Source code in posteriors/laplace/dense_fisher.py
```python
class DenseLaplaceState(NamedTuple):
    """State encoding a Normal distribution over parameters,
    with a dense precision matrix

    Attributes:
        params: Mean of the Normal distribution.
        prec: Precision matrix of the Normal distribution.
        aux: Auxiliary information from the log_posterior call.
    """

    params: TensorTree
    prec: torch.Tensor
    aux: Any = None
```

posteriors.laplace.dense_fisher.init(params, init_prec=0.0)

Initialise Normal distribution over parameters with a dense precision matrix.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `params` | `TensorTree` | Mean of the Normal distribution. | *required* |
| `init_prec` | `Tensor \| float` | Initial precision matrix. If it is a float, it is defined as an identity matrix scaled by that float. | `0.0` |

Returns:

| Type | Description |
| --- | --- |
| `DenseLaplaceState` | Initial DenseLaplaceState. |
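
For instance, a float `init_prec` is expanded to a scaled identity over the flattened parameters (a small illustrative check):

```python
import torch
from posteriors.laplace import dense_fisher

params = {"w": torch.zeros(2, 3), "b": torch.zeros(3)}  # 9 parameters in total
state = dense_fisher.init(params, init_prec=0.1)
assert state.prec.shape == (9, 9)
assert torch.equal(state.prec, 0.1 * torch.eye(9))
```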

Source code in posteriors/laplace/dense_fisher.py
```python
def init(
    params: TensorTree,
    init_prec: torch.Tensor | float = 0.0,
) -> DenseLaplaceState:
    """Initialise Normal distribution over parameters
    with a dense precision matrix.

    Args:
        params: Mean of the Normal distribution.
        init_prec: Initial precision matrix.
            If it is a float, it is defined as an identity matrix
            scaled by that float.

    Returns:
        Initial DenseLaplaceState.
    """

    if is_scalar(init_prec):
        num_params = tree_size(params)
        init_prec = init_prec * torch.eye(num_params, requires_grad=False)

    return DenseLaplaceState(params, init_prec)
```

posteriors.laplace.dense_fisher.update(state, batch, log_posterior, per_sample=False, inplace=False)

Adds the empirical Fisher information matrix, summed over the given batch, to the state's precision matrix.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `state` | `DenseLaplaceState` | Current state. | *required* |
| `batch` | `Any` | Input data to log_posterior. | *required* |
| `log_posterior` | `LogProbFn` | Function that takes parameters and input batch and returns the log posterior value (which can be unnormalised) as well as auxiliary information, e.g. from the model call. | *required* |
| `per_sample` | `bool` | If True, then log_posterior is assumed to return a vector of log posteriors for each sample in the batch. If False, then log_posterior is assumed to return a scalar log posterior for the whole batch, in which case torch.func.vmap will be called; this is typically slower than directly writing log_posterior to be per sample. | `False` |
| `inplace` | `bool` | If True, the state is updated in place. Otherwise, a new state is returned. | `False` |

Returns:

| Type | Description |
| --- | --- |
| `DenseLaplaceState` | Updated DenseLaplaceState. |
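
`update` can also be called directly, without going through `build`. A sketch reusing the hypothetical `params`, `log_posterior`, and `batches` from the build example above:

```python
state = dense_fisher.init(params)
for batch in batches:
    # Each call adds the batch's empirical Fisher to state.prec.
    state = dense_fisher.update(state, batch, log_posterior, per_sample=True)
```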

Source code in posteriors/laplace/dense_fisher.py
```python
def update(
    state: DenseLaplaceState,
    batch: Any,
    log_posterior: LogProbFn,
    per_sample: bool = False,
    inplace: bool = False,
) -> DenseLaplaceState:
    """Adds empirical Fisher information matrix of covariance summed over
    given batch.

    Args:
        state: Current state.
        batch: Input data to log_posterior.
        log_posterior: Function that takes parameters and input batch and
            returns the log posterior value (which can be unnormalised)
            as well as auxiliary information, e.g. from the model call.
        per_sample: If True, then log_posterior is assumed to return a vector of
            log posteriors for each sample in the batch. If False, then log_posterior
            is assumed to return a scalar log posterior for the whole batch, in which
            case torch.func.vmap will be called; this is typically slower than
            directly writing log_posterior to be per sample.
        inplace: If True, the state is updated in place. Otherwise, a new
            state is returned.

    Returns:
        Updated DenseLaplaceState.
    """
    if not per_sample:
        log_posterior = per_samplify(log_posterior)

    with torch.no_grad(), CatchAuxError():
        fisher, aux = empirical_fisher(
            lambda p: log_posterior(p, batch), has_aux=True, normalize=False
        )(state.params)

    if inplace:
        state.prec.data += fisher
        return state._replace(aux=aux)
    else:
        return DenseLaplaceState(state.params, state.prec + fisher, aux)
```

posteriors.laplace.dense_fisher.sample(state, sample_shape=torch.Size([]))

Sample from Normal distribution over parameters.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `state` | `DenseLaplaceState` | State encoding mean and precision matrix. | *required* |
| `sample_shape` | `Size` | Shape of the desired samples. | `Size([])` |

Returns:

| Type | Description |
| --- | --- |
| `TensorTree` | Sample(s) from the Normal distribution. |
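
A sketch of drawing samples, reusing the hypothetical `state` and toy `Linear(3, 1)` model from above. Each leaf of the returned TensorTree gains leading `sample_shape` dimensions:

```python
param_samples = dense_fisher.sample(state, sample_shape=torch.Size([100]))
print(param_samples["weight"].shape)  # torch.Size([100, 1, 3]) for the toy model
```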

Source code in posteriors/laplace/dense_fisher.py
```python
def sample(
    state: DenseLaplaceState,
    sample_shape: torch.Size = torch.Size([]),
) -> TensorTree:
    """Sample from Normal distribution over parameters.

    Args:
        state: State encoding mean and precision matrix.
        sample_shape: Shape of the desired samples.

    Returns:
        Sample(s) from the Normal distribution.
    """
    samples = torch.distributions.MultivariateNormal(
        loc=torch.zeros(state.prec.shape[0], device=state.prec.device),
        precision_matrix=state.prec,
        validate_args=False,
    ).sample(sample_shape)
    samples = samples.flatten(end_dim=-2)  # ensure samples is 2D
    mean_flat, unravel_func = tree_ravel(state.params)
    samples += mean_flat
    samples = torch.vmap(unravel_func)(samples)
    samples = tree_map(lambda x: x.reshape(sample_shape + x.shape[1:]), samples)
    return samples
```
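
Sampled parameters can then be pushed through the model functionally, e.g. for Monte Carlo predictive averaging. A sketch under the same toy-model assumptions as above:

```python
from torch.func import functional_call, vmap

def predict(p, x):
    return functional_call(model, p, (x,))

x_test = torch.randn(4, 3)
# Map over the leading sample dimension of every leaf in param_samples.
preds = vmap(predict, in_dims=(0, None))(param_samples, x_test)  # (100, 4, 1)
predictive_mean = preds.mean(0)  # average over parameter samples
```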