
Laplace Diagonal Fisher

posteriors.laplace.diag_fisher.build(log_posterior, per_sample=False, init_prec_diag=0.0)

Builds a transform for diagonal empirical Fisher information Laplace approximation.

The empirical Fisher is defined here as: $$ F(θ) = \sum_i ∇_θ \log p(y_i, θ | x_i) ∇_θ \log p(y_i, θ | x_i)^T $$ where \(p(y_i, θ | x_i)\) is the joint model distribution (equivalent to the posterior up to proportionality) with parameters \(θ\), inputs \(x_i\) and labels \(y_i\).

More info on empirical Fisher matrices can be found in [Martens, 2020](https://jmlr.org/papers/volume21/17-678/17-678.pdf) and their use within a Laplace approximation in [Daxberger et al, 2021](https://arxiv.org/abs/2106.14806).
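As a concrete illustration of the formula above (a toy sketch, not the posteriors API), the diagonal of the empirical Fisher is the sum over samples of the squared per-sample score, shown here for a hypothetical linear-Gaussian model with a flat prior:

```python
import torch

# Toy sketch: the diagonal of the empirical Fisher is the sum over samples
# of the squared per-sample score grad log p(y_i, theta | x_i).
torch.manual_seed(0)
theta = torch.tensor([0.5, -1.0], requires_grad=True)
xs = torch.randn(8, 2)
ys = xs @ torch.tensor([1.0, 2.0]) + 0.1 * torch.randn(8)

def log_joint(theta, x, y):
    # Unit-variance Gaussian likelihood, flat prior (joint proportional to posterior)
    return -0.5 * (y - x @ theta) ** 2

diag_fisher = torch.zeros(2)
for x, y in zip(xs, ys):
    (g,) = torch.autograd.grad(log_joint(theta, x, y), theta)
    diag_fisher += g.square()
```

Note the outer product \(∇\log p \, ∇\log p^T\) reduces to an elementwise square when only the diagonal is retained, which is what makes this approximation cheap to accumulate.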

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `log_posterior` | `LogProbFn` | Function that takes parameters and input batch and returns the log posterior value (which can be unnormalised) as well as auxiliary information, e.g. from the model call. | *required* |
| `per_sample` | `bool` | If `True`, `log_posterior` is assumed to return a vector of log posteriors, one per sample in the batch. If `False`, `log_posterior` is assumed to return a scalar log posterior for the whole batch; in this case `torch.func.vmap` will be called, which is typically slower than writing `log_posterior` to be per-sample directly. | `False` |
| `init_prec_diag` | `TensorTree \| float` | Initial diagonal precision matrix. Can be a tree matching `params` or a scalar. | `0.0` |

Returns:

| Type | Description |
| --- | --- |
| `Transform` | Diagonal empirical Fisher information Laplace approximation transform instance. |

Source code in posteriors/laplace/diag_fisher.py
def build(
    log_posterior: LogProbFn,
    per_sample: bool = False,
    init_prec_diag: TensorTree | float = 0.0,
) -> Transform:
    """Builds a transform for diagonal empirical Fisher information
    Laplace approximation.

    The empirical Fisher is defined here as:
    $$
    F(θ) = \\sum_i ∇_θ \\log p(y_i, θ | x_i) ∇_θ \\log p(y_i, θ | x_i)^T
    $$
    where $p(y_i, θ | x_i)$ is the joint model distribution (equivalent to the posterior
    up to proportionality) with parameters $θ$, inputs $x_i$ and labels $y_i$.

    More info on empirical Fisher matrices can be found in
    [Martens, 2020](https://jmlr.org/papers/volume21/17-678/17-678.pdf) and
    their use within a Laplace approximation in [Daxberger et al, 2021](https://arxiv.org/abs/2106.14806).

    Args:
        log_posterior: Function that takes parameters and input batch and
            returns the log posterior value (which can be unnormalised)
            as well as auxiliary information, e.g. from the model call.
        per_sample: If True, then log_posterior is assumed to return a vector of
            log posteriors for each sample in the batch. If False, then log_posterior
            is assumed to return a scalar log posterior for the whole batch, in this
            case torch.func.vmap will be called, this is typically slower than
            directly writing log_posterior to be per sample.
        init_prec_diag: Initial diagonal precision matrix.
            Can be tree like params or scalar.

    Returns:
        Diagonal empirical Fisher information Laplace approximation transform instance.
    """
    init_fn = partial(init, init_prec_diag=init_prec_diag)
    update_fn = partial(update, log_posterior=log_posterior, per_sample=per_sample)
    return Transform(init_fn, update_fn)

posteriors.laplace.diag_fisher.DiagLaplaceState

Bases: NamedTuple

State encoding a diagonal Normal distribution over parameters.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `params` | `TensorTree` | Mean of the Normal distribution. |
| `prec_diag` | `TensorTree` | Diagonal of the precision matrix of the Normal distribution. |
| `aux` | `Any` | Auxiliary information from the `log_posterior` call. |

Source code in posteriors/laplace/diag_fisher.py
class DiagLaplaceState(NamedTuple):
    """State encoding a diagonal Normal distribution over parameters.

    Attributes:
        params: Mean of the Normal distribution.
        prec_diag: Diagonal of the precision matrix of the Normal distribution.
        aux: Auxiliary information from the log_posterior call.
    """

    params: TensorTree
    prec_diag: TensorTree
    aux: Any = None
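A quick sketch of how such a `NamedTuple` state behaves (using a hypothetical `DiagState` stand-in): fields are immutable, and `_replace` returns a copy with selected fields updated, which is how the update step refreshes `aux` without mutating the rest of the state:

```python
import torch
from typing import Any, NamedTuple

# Hypothetical stand-in mirroring DiagLaplaceState's shape.
class DiagState(NamedTuple):
    params: torch.Tensor
    prec_diag: torch.Tensor
    aux: Any = None

state = DiagState(params=torch.zeros(3), prec_diag=torch.ones(3))
# _replace copies the tuple, swapping only the named field.
updated = state._replace(aux={"logits": None})
```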

posteriors.laplace.diag_fisher.init(params, init_prec_diag=0.0)

Initialise diagonal Normal distribution over parameters.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `params` | `TensorTree` | Mean of the Normal distribution. | *required* |
| `init_prec_diag` | `TensorTree \| float` | Initial diagonal precision matrix. Can be a tree matching `params` or a scalar. | `0.0` |

Returns:

| Type | Description |
| --- | --- |
| `DiagLaplaceState` | Initial DiagLaplaceState. |

Source code in posteriors/laplace/diag_fisher.py
def init(
    params: TensorTree,
    init_prec_diag: TensorTree | float = 0.0,
) -> DiagLaplaceState:
    """Initialise diagonal Normal distribution over parameters.

    Args:
        params: Mean of the Normal distribution.
        init_prec_diag: Initial diagonal precision matrix.
            Can be tree like params or scalar.

    Returns:
        Initial DiagLaplaceState.
    """
    if is_scalar(init_prec_diag):
        init_prec_diag = tree_map(
            lambda x: torch.full_like(x, init_prec_diag, requires_grad=x.requires_grad),
            params,
        )

    return DiagLaplaceState(params, init_prec_diag)
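The scalar branch above broadcasts a single precision value across the whole parameter tree. A minimal sketch of that behaviour, using torch's pytree `tree_map` and `torch.full_like` as in the source:

```python
import torch
from torch.utils._pytree import tree_map

# A scalar initial precision is expanded to a tensor tree with the same
# structure and shapes as params, one full_like tensor per leaf.
params = {"weight": torch.randn(3, 2), "bias": torch.randn(2)}
prec_diag = tree_map(lambda p: torch.full_like(p, 0.1), params)
```

Passing a full `TensorTree` instead skips this branch, letting you set per-parameter initial precisions (e.g. encoding a prior).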

posteriors.laplace.diag_fisher.update(state, batch, log_posterior, per_sample=False, inplace=False)

Adds the diagonal empirical Fisher information matrix, summed over the given batch, to the diagonal precision of the state.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `state` | `DiagLaplaceState` | Current state. | *required* |
| `batch` | `Any` | Input data to `log_posterior`. | *required* |
| `log_posterior` | `LogProbFn` | Function that takes parameters and input batch and returns the log posterior value (which can be unnormalised) as well as auxiliary information, e.g. from the model call. | *required* |
| `per_sample` | `bool` | If `True`, `log_posterior` is assumed to return a vector of log posteriors, one per sample in the batch. If `False`, `log_posterior` is assumed to return a scalar log posterior for the whole batch; in this case `torch.func.vmap` will be called, which is typically slower than writing `log_posterior` to be per-sample directly. | `False` |
| `inplace` | `bool` | If `True`, the state is updated in place; otherwise a new state is returned. | `False` |

Returns:

| Type | Description |
| --- | --- |
| `DiagLaplaceState` | Updated DiagLaplaceState. |

Source code in posteriors/laplace/diag_fisher.py
def update(
    state: DiagLaplaceState,
    batch: Any,
    log_posterior: LogProbFn,
    per_sample: bool = False,
    inplace: bool = False,
) -> DiagLaplaceState:
    """Adds diagonal empirical Fisher information matrix of covariance summed over
    given batch.

    Args:
        state: Current state.
        batch: Input data to log_posterior.
        log_posterior: Function that takes parameters and input batch and
            returns the log posterior value (which can be unnormalised)
            as well as auxiliary information, e.g. from the model call.
        per_sample: If True, then log_posterior is assumed to return a vector of
            log posteriors for each sample in the batch. If False, then log_posterior
            is assumed to return a scalar log posterior for the whole batch, in this
            case torch.func.vmap will be called, this is typically slower than
            directly writing log_posterior to be per sample.
        inplace: If True, then the state is updated in place, otherwise a new state
            is returned.

    Returns:
        Updated DiagLaplaceState.
    """
    if not per_sample:
        log_posterior = per_samplify(log_posterior)

    with torch.no_grad(), CatchAuxError():
        jac, aux = jacrev(log_posterior, has_aux=True)(state.params, batch)
        batch_diag_score_sq = tree_map(lambda j: j.square().sum(0), jac)

    def update_func(x, y):
        return x + y

    prec_diag = flexi_tree_map(
        update_func, state.prec_diag, batch_diag_score_sq, inplace=inplace
    )

    if inplace:
        return state._replace(aux=aux)
    return DiagLaplaceState(state.params, prec_diag, aux)
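The core computation of the update step can be sketched in isolation (with a hypothetical per-sample `log_posterior` for a linear-Gaussian model): `jacrev` yields one score vector per sample, and squaring then summing over the batch dimension gives the diagonal Fisher contribution:

```python
import torch
from torch.func import jacrev

# Hypothetical per-sample log posterior: returns one value per batch element.
def log_posterior(params, batch):
    x, y = batch
    return -0.5 * (y - x @ params) ** 2  # shape (batch_size,)

params = torch.zeros(3)
batch = (torch.randn(5, 3), torch.randn(5))

# Jacobian of a (5,)-valued function w.r.t. (3,)-shaped params: shape (5, 3),
# i.e. one score vector per sample.
jac = jacrev(log_posterior)(params, batch)
batch_diag_score_sq = jac.square().sum(0)  # shape (3,)
```

This added quantity only grows as batches accumulate, so the resulting precision tightens with the amount of data seen, as expected for a Laplace approximation.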

posteriors.laplace.diag_fisher.sample(state, sample_shape=torch.Size([]))

Sample from diagonal Normal distribution over parameters.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `state` | `DiagLaplaceState` | State encoding mean and diagonal precision. | *required* |
| `sample_shape` | `Size` | Shape of the desired samples. | `Size([])` |

Returns:

| Type | Description |
| --- | --- |
| `TensorTree` | Sample(s) from Normal distribution. |

Source code in posteriors/laplace/diag_fisher.py
def sample(
    state: DiagLaplaceState, sample_shape: torch.Size = torch.Size([])
) -> TensorTree:
    """Sample from diagonal Normal distribution over parameters.

    Args:
        state: State encoding mean and diagonal precision.
        sample_shape: Shape of the desired samples.

    Returns:
        Sample(s) from Normal distribution.
    """
    sd_diag = tree_map(lambda x: x.sqrt().reciprocal(), state.prec_diag)
    return diag_normal_sample(state.params, sd_diag, sample_shape=sample_shape)
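The precision-to-scale conversion above is simple enough to sketch directly: for a diagonal Normal, the standard deviation of each coordinate is one over the square root of its precision, after which sampling reduces to scaling and shifting standard normals (a sketch of the idea, not the `diag_normal_sample` internals):

```python
import torch

mean = torch.zeros(4)
prec_diag = torch.full((4,), 25.0)

# Standard deviation = 1 / sqrt(precision); here every entry is 0.2.
sd_diag = prec_diag.sqrt().reciprocal()

# Three independent draws from N(mean, diag(sd_diag**2)).
samples = mean + sd_diag * torch.randn(3, 4)
```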