
Laplace Dense GGN

posteriors.laplace.dense_ggn.build(forward, outer_log_likelihood, init_prec=0.0)

Builds a transform for a Generalized Gauss-Newton (GGN) Laplace approximation.

Equivalent to the (non-empirical) Fisher information matrix when outer_log_likelihood is an exponential-family log-likelihood whose natural parameter is the output of forward.
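To make the exponential-family condition concrete, here is a small NumPy check. It does not use posteriors; the logistic-regression model, data, and variable names are illustrative assumptions. For a Bernoulli likelihood parameterised by logits $z = Xθ$, the logit is the natural parameter, so the GGN $X^T \mathrm{diag}(p(1-p)) X$ should match the Fisher information obtained by taking the exact expectation, over labels $y \in \{0, 1\}$, of the outer product of score vectors.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

rng = np.random.default_rng(0)
X = rng.normal(size=(10, 3))      # 10 inputs, 3 parameters (illustrative data)
theta = rng.normal(size=3)
p = sigmoid(X @ theta)            # Bernoulli mean for each input

# GGN: J H J^T with J = X^T (parameters x outputs) and H = diag(p * (1 - p)),
# the Hessian of the Bernoulli negative log-likelihood with respect to the logits.
ggn = X.T @ np.diag(p * (1.0 - p)) @ X

# Fisher: for each input, average the outer product of the score
# d/dθ log p(y | θ) = (y - p) x over both possible labels y in {0, 1}.
fisher = np.zeros((3, 3))
for x_k, p_k in zip(X, p):
    for y, prob_y in ((1.0, p_k), (0.0, 1.0 - p_k)):
        score = (y - p_k) * x_k
        fisher += prob_y * np.outer(score, score)

print(np.allclose(ggn, fisher))  # True
```

The match holds because $\mathbb{E}[(y - p)^2] = p(1 - p)$ for a Bernoulli variable, which is exactly the curvature of the loss in $z$.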

forward should return auxiliary information (or torch.tensor([]) if there is none); outer_log_likelihood should not.

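The GGN is defined as $$ G(θ) = J_f(θ) H_l(z) J_f(θ)^T $$ where $z = f(θ)$ is the output of the forward function $f$ and $l(z)$ is a loss (negative log-likelihood) that maps the output of $f$ to a scalar. Here $J_f(θ)$ is the Jacobian of $f$ laid out with one row per parameter and one column per output, so $G(θ)$ is a square matrix indexed by the parameters.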
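The formula can be checked numerically on a model where everything is available in closed form. The sketch below is an illustration in NumPy, not the posteriors implementation: it assumes a linear logistic-regression forward $f(θ) = Xθ$ and a summed Bernoulli negative log-likelihood $l(z)$, lays out $J_f(θ)$ as a (parameters x outputs) matrix so that $J_f H_l J_f^T$ is well-typed, and compares $G(θ)$ with a finite-difference Hessian of $l(f(θ))$. Because $f$ is linear here, the term involving second derivatives of $f$ vanishes, so the GGN and the full Hessian should coincide.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def forward(theta, X):
    # f(θ): one logit per input, not reduced over the batch (m = number of inputs)
    return X @ theta

def loss(z, y):
    # l(z): Bernoulli negative log-likelihood, summed to a scalar
    return np.sum(np.logaddexp(0.0, z) - y * z)

def ggn(theta, X):
    z = forward(theta, X)
    J = X.T                        # J_f(θ), shape (d, m): J[i, k] = dz_k / dθ_i
    p = sigmoid(z)
    H = np.diag(p * (1.0 - p))     # H_l(z), shape (m, m)
    return J @ H @ J.T             # G(θ), shape (d, d)

def finite_diff_hessian(fn, theta, eps=1e-4):
    # Central-difference Hessian of a scalar function, used only as a reference
    d = theta.size
    basis = np.eye(d) * eps
    hess = np.zeros((d, d))
    for i in range(d):
        for j in range(d):
            hess[i, j] = (
                fn(theta + basis[i] + basis[j])
                - fn(theta + basis[i] - basis[j])
                - fn(theta - basis[i] + basis[j])
                + fn(theta - basis[i] - basis[j])
            ) / (4 * eps**2)
    return hess

rng = np.random.default_rng(1)
X = rng.normal(size=(20, 3))
y = (rng.random(20) < 0.5).astype(float)
theta = rng.normal(size=3)

G = ggn(theta, X)
H_full = finite_diff_hessian(lambda t: loss(forward(t, X), y), theta)
print(np.allclose(G, H_full, atol=1e-4))
```

For a nonlinear $f$ (e.g. a neural network) the two would differ, and the GGN has the practical advantage of being positive semi-definite whenever $H_l(z)$ is.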

More information on Fisher and GGN matrices can be found in Martens, 2020, and on their use within a Laplace approximation in Daxberger et al., 2021.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| forward | ForwardFn | Function that takes parameters and an input batch and returns a forward value (e.g. logits), not reduced over the batch, together with auxiliary information. | required |
| outer_log_likelihood | OuterLogProbFn | Function that takes the output of forward and the batch and returns the log-likelihood of the model output, with no auxiliary information. | required |
| init_prec | TensorTree \| float | Initial precision matrix. If it is a float, it is defined as an identity matrix scaled by that float. | 0.0 |
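A minimal sketch of a forward / outer_log_likelihood pair that follows this contract is shown below. It assumes a hypothetical logistic-regression model whose parameters are stored as a dict (a simple TensorTree), a batch given as an `(inputs, labels)` tuple, and NumPy arrays standing in for torch tensors (`np.array([])` plays the role of `torch.tensor([])`).

```python
import numpy as np

def forward(params, batch):
    # Returns per-example logits (not reduced over the batch) plus auxiliary info.
    X, _ = batch
    logits = X @ params["w"] + params["b"]
    aux = np.array([])  # stand-in for torch.tensor([]) when there is nothing to report
    return logits, aux

def outer_log_likelihood(logits, batch):
    # Takes forward's output and the batch; returns a scalar log-likelihood, no aux.
    _, y = batch
    return np.sum(y * logits - np.logaddexp(0.0, logits))

params = {"w": np.array([0.5, -1.0]), "b": np.array(0.1)}
batch = (np.array([[1.0, 2.0], [0.0, 1.0], [3.0, -1.0]]), np.array([1.0, 0.0, 1.0]))

logits, aux = forward(params, batch)
print(logits.shape)  # (3,)
print(outer_log_likelihood(logits, batch) <= 0.0)  # True
```

Summing (rather than averaging) over the batch in outer_log_likelihood keeps the per-batch GGN on the scale of the full-data log-likelihood, which matters when the GGN is accumulated across batches.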

Returns:

| Type | Description |
| --- | --- |
| Transform | GGN Laplace approximation transform instance. |

Source code in posteriors/laplace/dense_ggn.py
def build(
    forward: ForwardFn,
    outer_log_likelihood: OuterLogProbFn,
    init_prec: TensorTree | float = 0.0,
) -> Transform:
    """Builds a transform for a Generalized Gauss-Newton (GGN)
    Laplace approximation.

    Equivalent to the (non-empirical) Fisher information matrix when
    `outer_log_likelihood` is an exponential-family log-likelihood whose
    natural parameter is the output of `forward`.

    `forward` should return auxiliary information (or `torch.tensor([])` if
    there is none); `outer_log_likelihood` should not.

    The GGN is defined as
    $$
    G(θ) = J_f(θ) H_l(z) J_f(θ)^T
    $$
    where $z = f(θ)$ is the output of the forward function $f$ and $l(z)$
    is a loss (negative log-likelihood) that maps the output of $f$ to a scalar output.

    More information on Fisher and GGN matrices can be found in
    [Martens, 2020](https://jmlr.org/papers/volume21/17-678/17-678.pdf), and on
    their use within a Laplace approximation in [Daxberger et al., 2021](https://arxiv.org/abs/2106.14806).

    Args:
        forward: Function that takes parameters and input batch and
            returns a forward value (e.g. logits), not reduced over the batch,
            as well as auxiliary information.
        outer_log_likelihood: A function that takes the output of `forward` and batch
            then returns the log likelihood of the model output,
            with no auxiliary information.
        init_prec: Initial precision matrix.
            If it is a float, it is defined as an identity matrix
            scaled by that float.

    Returns:
        GGN Laplace approximation transform instance.
    """
    init_fn = partial(init, init_prec=init_prec)
    update_fn = partial(
        update, forward=forward, outer_log_likelihood=outer_log_likelihood
    )
    return Transform(init_fn, update_fn)
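The structure of build (bind configuration with functools.partial, return an init/update pair) can be mimicked in a few lines. In the NumPy sketch below, `MiniTransform`, `mini_init`, `mini_update`, `mini_build` and `logistic_ggn` are hypothetical stand-ins, not posteriors' API, and the parameters are a flat vector rather than a TensorTree. It illustrates one consequence of update calling ggn with `normalize=False` and leaving params fixed: after streaming through the data, the precision equals `init_prec * I` plus the full-data GGN, assuming outer_log_likelihood sums over the batch.

```python
from functools import partial
from typing import Callable, NamedTuple

import numpy as np

class MiniTransform(NamedTuple):
    # Hypothetical stand-in for a Transform: just an (init, update) pair.
    init: Callable
    update: Callable

def mini_init(params, init_prec=0.0):
    # Scalar init_prec becomes init_prec * I over all parameters
    return {"params": params, "prec": init_prec * np.eye(params.size)}

def mini_update(state, batch, ggn_fn):
    # Adds the (unnormalised) GGN of this batch to the running precision
    new_prec = state["prec"] + ggn_fn(state["params"], batch)
    return {"params": state["params"], "prec": new_prec}

def logistic_ggn(theta, batch):
    # GGN of a summed Bernoulli negative log-likelihood with logits X @ theta
    X, _ = batch
    p = 1.0 / (1.0 + np.exp(-(X @ theta)))
    return X.T @ np.diag(p * (1.0 - p)) @ X

def mini_build(ggn_fn, init_prec=0.0):
    # Bind configuration up front, as build does with functools.partial
    return MiniTransform(
        partial(mini_init, init_prec=init_prec),
        partial(mini_update, ggn_fn=ggn_fn),
    )

rng = np.random.default_rng(2)
X = rng.normal(size=(30, 4))
y = (rng.random(30) < 0.5).astype(float)
theta_map = rng.normal(size=4)  # pretend this is a trained (MAP) parameter vector

transform = mini_build(logistic_ggn, init_prec=1.0)
state = transform.init(theta_map)
for start in range(0, 30, 10):  # stream three batches of 10 examples
    state = transform.update(state, (X[start:start + 10], y[start:start + 10]))

full = np.eye(4) + logistic_ggn(theta_map, (X, y))
print(np.allclose(state["prec"], full))  # True
```

Because the mean is never changed by update, the Laplace approximation is typically built around parameters that have already been optimised.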

posteriors.laplace.dense_ggn.DenseLaplaceState

Bases: NamedTuple

State encoding a Normal distribution over parameters, with a dense precision matrix.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| params | TensorTree | Mean of the Normal distribution. |
| prec | Tensor | Precision matrix of the Normal distribution. |
| aux | Any | Auxiliary information from the forward call. |

Source code in posteriors/laplace/dense_ggn.py
class DenseLaplaceState(NamedTuple):
    """State encoding a Normal distribution over parameters,
    with a dense precision matrix.

    Attributes:
        params: Mean of the Normal distribution.
        prec: Precision matrix of the Normal distribution.
        aux: Auxiliary information from the forward call.
    """

    params: TensorTree
    prec: torch.Tensor
    aux: Any = None
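The sketch below mirrors the shape of this state with NumPy arrays; `ToyLaplaceState` is a hypothetical stand-in, not the library class. It shows that `prec` is a single dense (d, d) matrix over all flattened parameter entries, however the parameter tree is structured, that `aux` defaults to None, and that `_replace` (used by update) returns a modified copy rather than mutating the original NamedTuple.

```python
from typing import Any, NamedTuple

import numpy as np

class ToyLaplaceState(NamedTuple):
    # Hypothetical mirror of DenseLaplaceState with NumPy arrays in place of tensors
    params: Any
    prec: np.ndarray
    aux: Any = None

params = {"w": np.zeros((2, 3)), "b": np.zeros(3)}
num_params = sum(leaf.size for leaf in params.values())  # 6 + 3 = 9 entries in total
state = ToyLaplaceState(params, np.eye(num_params))

new_state = state._replace(aux="logits")  # NamedTuples are immutable; _replace copies
print(state.aux, new_state.aux)  # None logits
print(new_state.prec.shape)      # (9, 9)
```

The dense precision grows quadratically with the number of parameters, so this representation is practical mainly for small models or subsets of parameters.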

posteriors.laplace.dense_ggn.init(params, init_prec=0.0)

Initialise Normal distribution over parameters with a dense precision matrix.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| params | TensorTree | Mean of the Normal distribution. | required |
| init_prec | Tensor \| float | Initial precision matrix. If it is a float, it is defined as an identity matrix scaled by that float. | 0.0 |

Returns:

| Type | Description |
| --- | --- |
| DenseLaplaceState | Initial DenseLaplaceState. |

Source code in posteriors/laplace/dense_ggn.py
def init(
    params: TensorTree,
    init_prec: torch.Tensor | float = 0.0,
) -> DenseLaplaceState:
    """Initialise Normal distribution over parameters
    with a dense precision matrix.

    Args:
        params: Mean of the Normal distribution.
        init_prec: Initial precision matrix.
            If it is a float, it is defined as an identity matrix
            scaled by that float.

    Returns:
        Initial DenseLaplaceState.
    """

    if is_scalar(init_prec):
        num_params = tree_size(params)
        init_prec = init_prec * torch.eye(num_params, requires_grad=False)

    return DenseLaplaceState(params, init_prec)
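The scalar branch of init can be sketched in NumPy as follows. `count_params` and `init_sketch` are hypothetical helpers (playing the roles of tree_size and init), and `np.ndim(x) == 0` stands in for the library's scalar check. A scalar becomes a scaled identity over every parameter entry; a full matrix is passed through unchanged.

```python
import numpy as np

def count_params(tree):
    # Hypothetical helper: total number of scalar entries across a dict of arrays
    return sum(np.size(leaf) for leaf in tree.values())

def init_sketch(params, init_prec=0.0):
    # Scalar init_prec -> init_prec * I; a full matrix is used as given.
    if np.ndim(init_prec) == 0:
        init_prec = init_prec * np.eye(count_params(params))
    return params, init_prec

params = {"w": np.zeros((2, 2)), "b": np.zeros(2)}
_, prec = init_sketch(params, init_prec=0.5)
print(prec.shape, prec[0, 0], prec[0, 1])  # (6, 6) 0.5 0.0
```

With the default `init_prec=0.0` the starting precision is the zero matrix, so the final precision is made up entirely of accumulated GGN terms.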

posteriors.laplace.dense_ggn.update(state, batch, forward, outer_log_likelihood, inplace=False)

Adds the GGN matrix computed over the given batch to the precision matrix.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| state | DenseLaplaceState | Current state. | required |
| batch | Any | Input data to the model. | required |
| forward | ForwardFn | Function that takes parameters and an input batch and returns a forward value (e.g. logits), not reduced over the batch, together with auxiliary information. | required |
| outer_log_likelihood | OuterLogProbFn | Function that takes the output of forward and the batch and returns the log-likelihood of the model output, with no auxiliary information. | required |
| inplace | bool | If True, the state is updated in place; otherwise a new state is returned. | False |

Returns:

| Type | Description |
| --- | --- |
| DenseLaplaceState | Updated DenseLaplaceState. |

Source code in posteriors/laplace/dense_ggn.py
def update(
    state: DenseLaplaceState,
    batch: Any,
    forward: ForwardFn,
    outer_log_likelihood: OuterLogProbFn,
    inplace: bool = False,
) -> DenseLaplaceState:
    """Adds the GGN matrix computed over the given batch to the precision matrix.

    Args:
        state: Current state.
        batch: Input data to model.
        forward: Function that takes parameters and input batch and
            returns a forward value (e.g. logits), not reduced over the batch,
            as well as auxiliary information.
        outer_log_likelihood: A function that takes the output of `forward` and batch
            then returns the log likelihood of the model output,
            with no auxiliary information.
        inplace: If True, then the state is updated in place, otherwise a new state
            is returned.

    Returns:
        Updated DenseLaplaceState.
    """

    def outer_loss(z, batch):
        return -outer_log_likelihood(z, batch)

    with torch.no_grad(), CatchAuxError():
        ggn_batch, aux = ggn(
            lambda params: forward(params, batch),
            lambda z: outer_loss(z, batch),
            forward_has_aux=True,
            loss_has_aux=False,
            normalize=False,
        )(state.params)

    if inplace:
        state.prec.data += ggn_batch
        return state._replace(aux=aux)
    else:
        return DenseLaplaceState(state.params, state.prec + ggn_batch, aux)
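The difference between the two branches above comes down to aliasing. The NumPy analogue below (illustrative values only) shows that the out-of-place sum `prec + ggn_batch` allocates a new array and leaves the original untouched, while an in-place `+=` modifies the existing array, so every other reference to it observes the change. With torch, `state.prec.data += ggn_batch` likewise mutates the tensor held by the incoming state, so an earlier state object that shares that tensor will also see the updated precision.

```python
import numpy as np

prec = np.eye(2)
ggn_batch = np.array([[2.0, 1.0], [1.0, 2.0]])

# Out-of-place (like inplace=False): a new precision array is created.
new_prec = prec + ggn_batch
print(prec[0, 0], new_prec[0, 0])  # 1.0 3.0

# In-place (like inplace=True): the existing array is modified, so aliases see it.
alias = prec
prec += ggn_batch
print(alias[0, 0])  # 3.0
```

The in-place branch avoids allocating a second (d, d) matrix per batch, which can matter for large dense precisions.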

posteriors.laplace.dense_ggn.sample(state, sample_shape=torch.Size([]))

Sample from the Normal distribution over parameters.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| state | DenseLaplaceState | State encoding mean and precision matrix. | required |
| sample_shape | Size | Shape of the desired samples. | Size([]) |

Returns:

| Type | Description |
| --- | --- |
| TensorTree | Sample(s) from the Normal distribution. |

Source code in posteriors/laplace/dense_ggn.py
def sample(
    state: DenseLaplaceState,
    sample_shape: torch.Size = torch.Size([]),
) -> TensorTree:
    """Sample from the Normal distribution over parameters.

    Args:
        state: State encoding mean and precision matrix.
        sample_shape: Shape of the desired samples.

    Returns:
        Sample(s) from the Normal distribution.
    """
    samples = torch.distributions.MultivariateNormal(
        loc=torch.zeros(state.prec.shape[0], device=state.prec.device),
        precision_matrix=state.prec,
        validate_args=False,
    ).sample(sample_shape)
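    # Flatten any sample_shape (including the default empty shape) to 2D (n, d)
    samples = samples.reshape(-1, samples.shape[-1])
    mean_flat, unravel_func = tree_ravel(state.params)
    samples += mean_flat
    samples = torch.vmap(unravel_func)(samples)
    # Restore sample_shape as the leading dimensions of every parameter leaf
    samples = tree_map(lambda x: x.reshape(sample_shape + x.shape[1:]), samples)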
    return samples
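Sampling directly from a precision matrix avoids ever forming the covariance. The NumPy sketch below is not the library's code (which delegates to torch.distributions.MultivariateNormal with `precision_matrix`); the helper `sample_from_precision` is hypothetical, takes a flat mean vector, and omits unravelling back into a parameter tree. It uses the Cholesky factor $P = L L^T$: if $z \sim N(0, I)$ then $L^{-T} z$ has covariance $(L L^T)^{-1} = P^{-1}$. The Cholesky step requires $P$ to be positive definite, which is likely one reason to pass a positive init_prec when the accumulated GGN may be rank-deficient.

```python
import numpy as np

def sample_from_precision(mean, prec, sample_shape, rng):
    # Draw x ~ N(mean, prec^{-1}) via prec = L L^T and x = mean + L^{-T} z.
    L = np.linalg.cholesky(prec)      # fails if prec is not positive definite
    n = int(np.prod(sample_shape))    # np.prod(()) == 1, so () gives a single draw
    z = rng.standard_normal((n, mean.size))
    x = mean + np.linalg.solve(L.T, z.T).T   # 2D: (n, d)
    return x.reshape(tuple(sample_shape) + mean.shape)

rng = np.random.default_rng(3)
mean = np.array([1.0, -2.0])
prec = np.array([[2.0, 0.5], [0.5, 1.0]])

draws = sample_from_precision(mean, prec, (200_000,), rng)
print(draws.shape)  # (200000, 2)
print(np.allclose(np.cov(draws.T), np.linalg.inv(prec), atol=0.02))
print(sample_from_precision(mean, prec, (), rng).shape)  # (2,)
```

Flattening sample_shape to a single leading dimension before the linear solve, then reshaping at the end, is the same pattern the library function follows with torch.vmap over the unravel function.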