Laplace Diagonal GGN

posteriors.laplace.diag_ggn.build(forward, outer_log_likelihood, init_prec_diag=0.0)

Builds a transform for a diagonal Generalized Gauss-Newton (GGN) Laplace approximation.

This is equivalent to the diagonal of the (non-empirical) Fisher information matrix when outer_log_likelihood is an exponential-family log-likelihood whose natural parameter is the output of forward.
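A quick numerical check of this claim for one exponential family, assuming a categorical likelihood parameterised by its logits (the natural parameter): the Hessian of the negative log-likelihood with respect to the logits equals the Fisher information with respect to the logits, diag(p) − p pᵀ, whatever the observed label. Sandwiching either matrix between the same Jacobians therefore gives the same matrix. The names nll and z are illustrative, and the sketch uses plain torch.func rather than posteriors.

```python
import torch
from torch.func import grad, jacrev

torch.manual_seed(0)
num_classes = 4
z = torch.randn(num_classes)  # logits: natural parameters of a categorical distribution


def nll(logits, label):
    # Negative log-likelihood of `label` under Categorical(logits=logits)
    return -torch.log_softmax(logits, dim=-1)[label]


# Hessian of the loss w.r.t. the logits (reverse-over-reverse); label 0 is arbitrary
hess = jacrev(jacrev(lambda zz: nll(zz, 0)))(z)

# Fisher w.r.t. the logits: E_{y ~ p(.|z)}[s s^T] with score s = grad_z log p(y | z)
p = torch.softmax(z, dim=-1)
fisher = torch.zeros(num_classes, num_classes)
for label in range(num_classes):
    score = grad(lambda zz: -nll(zz, label))(z)
    fisher += p[label] * torch.outer(score, score)

closed_form = torch.diag(p) - torch.outer(p, p)
```

The Hessian does not depend on the label, which is why the GGN matches the true Fisher rather than the empirical Fisher (the latter would use only the observed label's score outer product).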

forward should return auxiliary information (or torch.tensor([])); outer_log_likelihood should not.

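The GGN is defined as

$$ G(θ) = J_f(θ)^T H_l(z) J_f(θ) $$

where \(z = f(θ)\) is the output of the forward function \(f\), \(J_f(θ)\) is the Jacobian of \(f\) with respect to \(θ\), \(l(z)\) is a loss (negative log-likelihood) that maps the output of \(f\) to a scalar, and \(H_l(z)\) is the Hessian of \(l\) with respect to \(z\).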
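To make the formula concrete, here is a minimal sketch assuming a linear model f(θ) = Xθ with a unit-variance Gaussian loss l(z) = ½‖z − y‖². The Jacobian is laid out with one row per model output, so the GGN is formed as J.T @ H @ J; H is the identity and the diagonal reduces to the column-wise sum of squares of X. Materialising the full matrix like this is only practical for tiny models and is not meant to reflect how posteriors computes the diagonal internally.

```python
import torch
from torch.func import jacrev

torch.manual_seed(0)
X = torch.randn(5, 3)   # 5 inputs, 3 features
y = torch.randn(5)      # regression targets
theta = torch.randn(3)  # parameters


def f(theta):
    # Forward function: one output per input, not reduced over the batch
    return X @ theta


def l(z):
    # Loss on the outputs: Gaussian negative log-likelihood (unit variance), up to a constant
    return 0.5 * ((z - y) ** 2).sum()


z = f(theta)
J = jacrev(f)(theta)       # (5, 3): one row per output, one column per parameter
H = jacrev(jacrev(l))(z)   # (5, 5): Hessian of the loss w.r.t. z (the identity here)
G = J.T @ H @ J            # (3, 3) GGN
ggn_diag = torch.diagonal(G)

# For this model J = X and H = I, so diag(G) is the column-wise sum of squares of X
closed_form = (X ** 2).sum(dim=0)
```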

More information on Fisher and GGN matrices can be found in Martens, 2020, and on their use within a Laplace approximation in Daxberger et al., 2021.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| forward | ForwardFn | Function that takes parameters and an input batch and returns a forward value (e.g. logits), not reduced over the batch, together with auxiliary information. | required |
| outer_log_likelihood | OuterLogProbFn | Function that takes the output of forward and the batch and returns the log likelihood of the model output, with no auxiliary information. | required |
| init_prec_diag | TensorTree \| float | Initial diagonal of the precision matrix: either a tree with the same structure as params, or a scalar. | 0.0 |
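The two callables are easiest to see side by side. A minimal sketch, assuming a linear classifier whose params are a dict with hypothetical keys "W" and "b" and a batch that is an (inputs, labels) tuple: forward returns per-example logits plus an empty auxiliary tensor, while outer_log_likelihood returns a single scalar with no auxiliary output. Summing (rather than averaging) the log likelihood over the batch is a choice made here; it makes the precision accumulated by update scale with the number of data points.

```python
import torch
from torch.distributions import Categorical


def forward(params, batch):
    # Per-example logits (not reduced over the batch) plus empty auxiliary information
    inputs, _ = batch
    logits = inputs @ params["W"] + params["b"]
    return logits, torch.tensor([])


def outer_log_likelihood(logits, batch):
    # Scalar log likelihood of the labels given the logits; no auxiliary output
    _, labels = batch
    return Categorical(logits=logits).log_prob(labels).sum()


torch.manual_seed(0)
params = {"W": torch.randn(4, 3), "b": torch.zeros(3)}
batch = (torch.randn(6, 4), torch.randint(0, 3, (6,)))

logits, aux = forward(params, batch)
log_lik = outer_log_likelihood(logits, batch)
```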

Returns:

| Type | Description |
| --- | --- |
| Transform | Diagonal GGN Laplace approximation transform instance. |

Source code in posteriors/laplace/diag_ggn.py
def build(
    forward: ForwardFn,
    outer_log_likelihood: OuterLogProbFn,
    init_prec_diag: TensorTree | float = 0.0,
) -> Transform:
    """Builds a transform for a diagonal Generalized Gauss-Newton (GGN)
    Laplace approximation.

    Equivalent to the diagonal of the (non-empirical) Fisher information matrix when
    the `outer_log_likelihood` is exponential family with natural parameter equal to
    the output from `forward`.

    `forward` should output auxiliary information (or `torch.tensor([])`),
    `outer_log_likelihood` should not.

    The GGN is defined as
    $$
    G(θ) = J_f(θ) H_l(z) J_f(θ)^T
    $$
    where $z = f(θ)$ is the output of the forward function $f$ and $l(z)$
    is a loss (negative log-likelihood) that maps the output of $f$ to a scalar output.

    More info on Fisher and GGN matrices can be found in
    [Martens, 2020](https://jmlr.org/papers/volume21/17-678/17-678.pdf) and
    their use within a Laplace approximation in [Daxberger et al, 2021](https://arxiv.org/abs/2106.14806).

    Args:
        forward: Function that takes parameters and input batch and
            returns a forward value (e.g. logits), not reduced over the batch,
            as well as auxiliary information.
        outer_log_likelihood: A function that takes the output of `forward` and batch
            then returns the log likelihood of the model output,
            with no auxiliary information.
        init_prec_diag: Initial diagonal precision matrix.
            Can be tree like params or scalar.

    Returns:
        Diagonal GGN Laplace approximation transform instance.
    """
    init_fn = partial(init, init_prec_diag=init_prec_diag)
    update_fn = partial(
        update, forward=forward, outer_log_likelihood=outer_log_likelihood
    )
    return Transform(init_fn, update_fn)
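build itself does no computation: it uses functools.partial to bind the configuration (init_prec_diag into init, forward and outer_log_likelihood into update) and returns the pair as a Transform. The toy below reproduces that pattern with hypothetical names (ToyTransform, toy_init, toy_update, toy_build) and a trivial running sum in place of the GGN, so only the structure carries over.

```python
from functools import partial
from typing import Any, Callable, NamedTuple


class ToyTransform(NamedTuple):
    # Hypothetical stand-in for a transform: a pair of functions
    init: Callable[..., Any]
    update: Callable[..., Any]


def toy_init(params, init_total=0.0):
    # Build the initial state from params plus bound configuration
    return {"params": params, "total": init_total}


def toy_update(state, batch, scale):
    # Return a new state; `scale` is configuration bound by toy_build
    return {"params": state["params"], "total": state["total"] + scale * sum(batch)}


def toy_build(scale, init_total=0.0):
    # Bind configuration once; callers then pass only params, or (state, batch)
    return ToyTransform(
        partial(toy_init, init_total=init_total),
        partial(toy_update, scale=scale),
    )


transform = toy_build(scale=2.0, init_total=1.0)
state = transform.init([0.0])
for batch in ([1.0, 2.0], [3.0]):
    state = transform.update(state, batch)
```

Keeping configuration out of the call signature means the training loop only ever deals with params, state and batch.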

posteriors.laplace.diag_ggn.DiagLaplaceState

Bases: NamedTuple

State encoding a diagonal Normal distribution over parameters.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| params | TensorTree | Mean of the Normal distribution. |
| prec_diag | TensorTree | Diagonal of the precision matrix of the Normal distribution. |
| aux | Any | Auxiliary information returned by forward during the most recent update. |

Source code in posteriors/laplace/diag_ggn.py
class DiagLaplaceState(NamedTuple):
    """State encoding a diagonal Normal distribution over parameters.

    Attributes:
        params: Mean of the Normal distribution.
        prec_diag: Diagonal of the precision matrix of the Normal distribution.
        aux: Auxiliary information from the log_posterior call.
    """

    params: TensorTree
    prec_diag: TensorTree
    aux: Any = None
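Because the state is a NamedTuple, the container is immutable but its leaves are ordinary tensors. The sketch below uses a hypothetical mirror, ToyDiagState, with the same three fields to show two consequences visible in the update source: aux defaults to None, and _replace returns a new tuple that shares the original tensors, so an in-place edit to prec_diag is seen through both.

```python
from typing import Any, NamedTuple

import torch


class ToyDiagState(NamedTuple):
    # Hypothetical mirror of DiagLaplaceState's fields
    params: Any
    prec_diag: Any
    aux: Any = None


state = ToyDiagState({"w": torch.zeros(3)}, {"w": torch.ones(3)})  # aux defaults to None

# _replace builds a new tuple but reuses the same leaf objects
new_state = state._replace(aux="latest forward aux")
new_state.prec_diag["w"].add_(2.0)  # in-place edit of the shared precision tensor
```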

posteriors.laplace.diag_ggn.init(params, init_prec_diag=0.0)

Initialise a diagonal Normal distribution over parameters.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| params | TensorTree | Mean of the Normal distribution. | required |
| init_prec_diag | TensorTree \| float | Initial diagonal of the precision matrix: either a tree with the same structure as params, or a scalar. | 0.0 |

Returns:

| Type | Description |
| --- | --- |
| DiagLaplaceState | Initial DiagLaplaceState. |

Source code in posteriors/laplace/diag_ggn.py
def init(
    params: TensorTree,
    init_prec_diag: TensorTree | float = 0.0,
) -> DiagLaplaceState:
    """Initialise diagonal Normal distribution over parameters.

    Args:
        params: Mean of the Normal distribution.
        init_prec_diag: Initial diagonal precision matrix.
            Can be tree like params or scalar.

    Returns:
        Initial DiagLaplaceState.
    """
    if is_scalar(init_prec_diag):
        init_prec_diag = tree_map(
            lambda x: torch.full_like(x, init_prec_diag, requires_grad=x.requires_grad),
            params,
        )

    return DiagLaplaceState(params, init_prec_diag)
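When init_prec_diag is a scalar, init fills a tensor shaped like each parameter leaf with that value and copies the leaf's requires_grad flag. A minimal sketch of that broadcast, assuming params is a flat dict of tensors; init_prec_diag_like is a hypothetical helper, and it treats only Python int or float as scalars rather than reproducing the library's is_scalar check.

```python
import torch


def init_prec_diag_like(params, init_prec_diag=0.0):
    # Hypothetical helper mirroring init(): a Python scalar is broadcast to every
    # leaf (matching shape, dtype and requires_grad); a dict of tensors is used as-is
    if isinstance(init_prec_diag, (int, float)):
        return {
            name: torch.full_like(x, init_prec_diag, requires_grad=x.requires_grad)
            for name, x in params.items()
        }
    return init_prec_diag


params = {"W": torch.randn(2, 3, requires_grad=True), "b": torch.zeros(3)}
prec_diag = init_prec_diag_like(params, 1.0)
```

With the default of 0.0 the precision starts at zero, so no prior precision is added before the first update.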

posteriors.laplace.diag_ggn.update(state, batch, forward, outer_log_likelihood, inplace=False)

Adds the diagonal of the GGN matrix, summed over the given batch, to the diagonal precision.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| state | DiagLaplaceState | Current state. | required |
| batch | Any | Input data to the model. | required |
| forward | ForwardFn | Function that takes parameters and an input batch and returns a forward value (e.g. logits), not reduced over the batch, together with auxiliary information. | required |
| outer_log_likelihood | OuterLogProbFn | Function that takes the output of forward and the batch and returns the log likelihood of the model output, with no auxiliary information. | required |
| inplace | bool | If True, the state is updated in place; otherwise, a new state is returned. | False |

Returns:

| Type | Description |
| --- | --- |
| DiagLaplaceState | Updated DiagLaplaceState. |

Source code in posteriors/laplace/diag_ggn.py
def update(
    state: DiagLaplaceState,
    batch: Any,
    forward: ForwardFn,
    outer_log_likelihood: OuterLogProbFn,
    inplace: bool = False,
) -> DiagLaplaceState:
    """Adds diagonal GGN matrix of covariance summed over given batch.

    Args:
        state: Current state.
        batch: Input data to model.
        forward: Function that takes parameters and input batch and
            returns a forward value (e.g. logits), not reduced over the batch,
            as well as auxiliary information.
        outer_log_likelihood: A function that takes the output of `forward` and batch
            then returns the log likelihood of the model output,
            with no auxiliary information.
        inplace: If True, then the state is updated in place, otherwise a new state
            is returned.

    Returns:
        Updated DiagLaplaceState.
    """

    def outer_loss(z, batch):
        return -outer_log_likelihood(z, batch)

    with torch.no_grad(), CatchAuxError():
        diag_ggn_batch, aux = diag_ggn(
            lambda params: forward(params, batch),
            lambda z: outer_loss(z, batch),
            forward_has_aux=True,
            loss_has_aux=False,
            normalize=False,
        )(state.params)

    def update_func(x, y):
        return x + y

    prec_diag = flexi_tree_map(
        update_func, state.prec_diag, diag_ggn_batch, inplace=inplace
    )

    if inplace:
        return state._replace(aux=aux)
    return DiagLaplaceState(state.params, prec_diag, aux)
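The update passes normalize=False and its docstring describes the result as summed over the batch, so each batch's diagonal GGN is added to prec_diag without averaging. Because the loss is a sum over examples, the GGN is a sum of per-example terms, and accumulating batch by batch gives the same diagonal as a single pass over all the data. A minimal torch.func sketch of that property, assuming a logistic-regression model with a flat parameter vector; forward, loss and diag_ggn_batch are illustrative names, and the explicit Jacobian-Hessian product is not necessarily how the library's diag_ggn helper works.

```python
import torch
import torch.nn.functional as F
from torch.func import jacrev

torch.manual_seed(0)
X = torch.randn(8, 3)
y = torch.randint(0, 2, (8,)).float()
theta = torch.randn(3)


def forward(theta, x):
    # One logit per input, not reduced over the batch
    return x @ theta


def loss(z, labels):
    # Bernoulli negative log-likelihood from logits, summed over the batch
    return (F.softplus(z) - labels * z).sum()


def diag_ggn_batch(theta, x, labels):
    # Explicit diag(J^T H J): fine for a 3-parameter model, too costly for real ones
    J = jacrev(forward)(theta, x)                          # (batch, params)
    H = jacrev(jacrev(loss))(forward(theta, x), labels)    # (batch, batch)
    return torch.diagonal(J.T @ H @ J)


# Start from init_prec_diag = 0.0 and add one batch at a time (out of place)
prec_diag = torch.zeros(3)
for xb, yb in zip(X.split(4), y.split(4)):
    prec_diag = prec_diag + diag_ggn_batch(theta, xb, yb)

full_batch = diag_ggn_batch(theta, X, y)
```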

posteriors.laplace.diag_ggn.sample(state, sample_shape=torch.Size([]))

Sample from a diagonal Normal distribution over parameters.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| state | DiagLaplaceState | State encoding the mean and diagonal precision. | required |
| sample_shape | torch.Size | Shape of the desired samples. | torch.Size([]) |

Returns:

| Type | Description |
| --- | --- |
| TensorTree | Sample(s) from the Normal distribution. |

Source code in posteriors/laplace/diag_ggn.py
def sample(
    state: DiagLaplaceState, sample_shape: torch.Size = torch.Size([])
) -> TensorTree:
    """Sample from diagonal Normal distribution over parameters.

    Args:
        state: State encoding mean and diagonal precision.
        sample_shape: Shape of the desired samples.

    Returns:
        Sample(s) from Normal distribution.
    """
    sd_diag = tree_map(lambda x: x.sqrt().reciprocal(), state.prec_diag)
    return diag_normal_sample(state.params, sd_diag, sample_shape=sample_shape)
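sample converts precision to standard deviation elementwise, σ = 1/√precision, and passes the mean and σ to diag_normal_sample. A minimal mirror, assuming diag_normal_sample draws mean + σ·ε with ε standard normal and sample_shape prepended to each leaf's shape; sample_diag_normal is a hypothetical name, and params are assumed to be a flat dict of tensors.

```python
import torch


def sample_diag_normal(mean, prec_diag, sample_shape=torch.Size([])):
    # Hypothetical mirror of sample(): sd = 1 / sqrt(precision), then mean + sd * eps
    draws = {}
    for name, mu in mean.items():
        sd = prec_diag[name].sqrt().reciprocal()
        eps = torch.randn(torch.Size(sample_shape) + mu.shape)
        draws[name] = mu + sd * eps
    return draws


torch.manual_seed(0)
mean = {"w": torch.tensor([0.0, 1.0])}
prec_diag = {"w": torch.tensor([4.0, 0.25])}  # standard deviations 0.5 and 2.0
draws = sample_diag_normal(mean, prec_diag, torch.Size([100_000]))
```

Since σ is the reciprocal square root of the precision, any entry whose precision is still zero (possible with the default init_prec_diag=0.0 if the GGN diagonal for that parameter is zero) gives an infinite standard deviation; a positive init_prec_diag keeps every entry positive as long as the accumulated GGN diagonal is non-negative.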