
Laplace Dense Hessian

posteriors.laplace.dense_hessian.build(log_posterior, init_prec=0.0, epsilon=0.0, rescale=1.0)

Builds a transform for dense Hessian Laplace.

Warning: The Hessian is not guaranteed to be positive definite, so setting epsilon > 0 should be considered.
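For orientation (standard Laplace-approximation background, assumed rather than stated on this page): the transform fits a Gaussian whose precision is the Hessian of the negative log posterior at the current parameters, stabilised on the diagonal by epsilon (rescale additionally multiplies H before the jitter is added, per the source below):

```latex
p(\theta \mid \mathcal{D}) \approx \mathcal{N}\big(\theta^{*}, \Sigma\big),
\qquad
\Sigma^{-1} = H + \epsilon I,
\qquad
H = -\nabla_{\theta}^{2} \log p(\theta \mid \mathcal{D}) \big|_{\theta = \theta^{*}}
```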

Parameters:

- log_posterior (LogProbFn, required): Function that takes parameters and an input batch and returns the log posterior value (which can be unnormalised) as well as auxiliary information, e.g. from the model call.
- init_prec (Tensor | float, default 0.0): Initial precision matrix. If a float, it is interpreted as an identity matrix scaled by that float.
- epsilon (float, default 0.0): Added to the diagonal of the Hessian for numerical stability.
- rescale (float, default 1.0): Value to multiply the Hessian by (e.g. to normalize by batch size).

Returns:

- Transform: Hessian Laplace transform instance.

Source code in posteriors/laplace/dense_hessian.py
def build(
    log_posterior: LogProbFn,
    init_prec: torch.Tensor | float = 0.0,
    epsilon: float = 0.0,
    rescale: float = 1.0,
) -> Transform:
    """Builds a transform for dense Hessian Laplace.

    **Warning:**
    The Hessian is not guaranteed to be positive definite,
    so setting epsilon > 0 ought to be considered.

    Args:
        log_posterior: Function that takes parameters and input batch and
            returns the log posterior value (which can be unnormalised)
            as well as auxiliary information, e.g. from the model call.
        init_prec: Initial precision matrix.
            If it is a float, it is defined as an identity matrix
            scaled by that float.
        epsilon: Added to the diagonal of the Hessian
            for numerical stability.
        rescale: Value to multiply the Hessian by
            (i.e. to normalize by batch size)

    Returns:
        Hessian Laplace transform instance.
    """
    init_fn = partial(init, init_prec=init_prec)
    update_fn = partial(
        update,
        log_posterior=log_posterior,
        epsilon=epsilon,
        rescale=rescale,
    )
    return Transform(init_fn, update_fn)
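A minimal end-to-end sketch (not from the library's docs): the toy log_posterior below is hypothetical, and the sketch assumes the returned Transform exposes the init and update functions bound above.

```python
import torch
import posteriors

# Hypothetical Gaussian model: a LogProbFn must return (log_posterior_value, aux).
def log_posterior(params, batch):
    x, y = batch
    pred = x @ params["w"]
    log_lik = torch.distributions.Normal(pred, 1.0).log_prob(y).sum()
    log_prior = torch.distributions.Normal(0.0, 1.0).log_prob(params["w"]).sum()
    return log_lik + log_prior, {}

params = {"w": torch.zeros(3)}
batch = (torch.randn(10, 3), torch.randn(10))

transform = posteriors.laplace.dense_hessian.build(log_posterior, epsilon=1e-4)
state = transform.init(params)
state = transform.update(state, batch)  # state.prec now holds the batch Hessian + 1e-4 * I
```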

posteriors.laplace.dense_hessian.DenseLaplaceState

Bases: NamedTuple

State encoding a Normal distribution over parameters, with a dense precision matrix.

Attributes:

- params (TensorTree): Mean of the Normal distribution.
- prec (Tensor): Precision matrix of the Normal distribution.
- aux (Any): Auxiliary information from the log_posterior call.

Source code in posteriors/laplace/dense_hessian.py
class DenseLaplaceState(NamedTuple):
    """State encoding a Normal distribution over parameters,
    with a dense precision matrix

    Attributes:
        params: Mean of the Normal distribution.
        prec: Precision matrix of the Normal distribution.
        aux: Auxiliary information from the log_posterior call.
    """

    params: TensorTree
    prec: torch.Tensor
    aux: Any = None
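For illustration (hypothetical values): prec is a single dense matrix over the flattened parameter vector, whatever the tree structure of params.

```python
import torch
from posteriors.laplace.dense_hessian import DenseLaplaceState

# Two scalar parameters -> a 2 x 2 dense precision matrix.
state = DenseLaplaceState(
    params={"a": torch.tensor(0.5), "b": torch.tensor(-1.0)},
    prec=torch.eye(2),
)
print(state.prec.shape)  # torch.Size([2, 2])
print(state.aux)         # None (the default)
```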

posteriors.laplace.dense_hessian.init(params, init_prec=0.0)

Initialise Normal distribution over parameters with a dense precision matrix.

Parameters:

- params (TensorTree, required): Mean of the Normal distribution.
- init_prec (Tensor | float, default 0.0): Initial precision matrix. If a float, it is interpreted as an identity matrix scaled by that float.

Returns:

- DenseLaplaceState: Initial DenseLaplaceState.

Source code in posteriors/laplace/dense_hessian.py
def init(
    params: TensorTree,
    init_prec: torch.Tensor | float = 0.0,
) -> DenseLaplaceState:
    """Initialise Normal distribution over parameters
    with a dense precision matrix.

    Args:
        params: Mean of the Normal distribution.
        init_prec: Initial precision matrix.
            If it is a float, it is defined as an identity matrix
            scaled by that float.

    Returns:
        Initial DenseLaplaceState.
    """

    if is_scalar(init_prec):
        num_params = tree_size(params)
        init_prec = init_prec * torch.eye(num_params, requires_grad=False)

    return DenseLaplaceState(params, init_prec)
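A short sketch of the scalar branch above: a float init_prec expands to a scaled identity over the total flattened parameter count.

```python
import torch
from posteriors.laplace.dense_hessian import init

params = {"w": torch.zeros(3), "b": torch.zeros(2)}
state = init(params, init_prec=0.1)

print(state.prec.shape)                                # torch.Size([5, 5]): 3 + 2 parameters
print(torch.allclose(state.prec, 0.1 * torch.eye(5)))  # True
```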

posteriors.laplace.dense_hessian.update(state, batch, log_posterior, epsilon=0.0, rescale=1.0, inplace=False)

Adds the Hessian of the negative log-posterior over the given batch.

Warning: The Hessian is not guaranteed to be positive definite, so setting epsilon > 0 should be considered.

Parameters:

- state (DenseLaplaceState, required): Current state.
- batch (Any, required): Input data to log_posterior.
- log_posterior (LogProbFn, required): Function that takes parameters and an input batch and returns the log posterior value (which can be unnormalised).
- epsilon (float, default 0.0): Added to the diagonal of the Hessian for numerical stability.
- rescale (float, default 1.0): Value to multiply the Hessian by (e.g. to normalize by batch size).
- inplace (bool, default False): If True, the state is updated in place. Otherwise, a new state is returned.

Returns:

- DenseLaplaceState: Updated DenseLaplaceState.

Source code in posteriors/laplace/dense_hessian.py
def update(
    state: DenseLaplaceState,
    batch: Any,
    log_posterior: LogProbFn,
    epsilon: float = 0.0,
    rescale: float = 1.0,
    inplace: bool = False,
) -> DenseLaplaceState:
    """Adds the Hessian of the negative log-posterior over given batch.

    **Warning:**
    The Hessian is not guaranteed to be positive definite,
    so setting epsilon > 0 ought to be considered.

    Args:
        state: Current state.
        batch: Input data to log_posterior.
        log_posterior: Function that takes parameters and input batch and
            returns the log posterior value (which can be unnormalised)
        epsilon: Added to the diagonal of the Hessian
            for numerical stability.
        rescale: Value to multiply the Hessian by
            (i.e. to normalize by batch size)
        inplace: If True, the state is updated in place. Otherwise, a new
            state is returned.

    Returns:
        Updated DenseLaplaceState.
    """
    with torch.no_grad(), CatchAuxError():
        flat_params, params_unravel = tree_ravel(state.params)
        num_params = flat_params.numel()

        def neg_log_p(p_flat):
            value, aux = log_posterior(params_unravel(p_flat), batch)
            return -value, aux

        hess, aux = jacfwd(jacrev(neg_log_p, has_aux=True), has_aux=True)(flat_params)
        hess = hess * rescale + epsilon * torch.eye(num_params)

    if inplace:
        state.prec.data += hess
        return state._replace(aux=aux)
    else:
        return DenseLaplaceState(state.params, state.prec + hess, aux)
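Because each call adds a rescaled batch Hessian to state.prec, fitting over a dataset is a loop over batches. A sketch continuing the build example above (params, batch and log_posterior as defined there); rescale = 1 / num_batches is one plausible averaging choice, not one prescribed by the docs:

```python
from posteriors.laplace.dense_hessian import init, update

batches = [batch] * 4  # stand-in for a real dataloader
state = init(params)
for b in batches:
    state = update(state, b, log_posterior, epsilon=1e-4, rescale=1.0 / len(batches))
# state.prec now holds the averaged Hessian of the negative log posterior,
# with epsilon added to the diagonal once per update call.
```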

posteriors.laplace.dense_hessian.sample(state, sample_shape=torch.Size([]))

Sample from Normal distribution over parameters.

Parameters:

- state (DenseLaplaceState, required): State encoding mean and precision matrix.
- sample_shape (Size, default torch.Size([])): Shape of the desired samples.

Returns:

- TensorTree: Sample(s) from the Normal distribution.

Source code in posteriors/laplace/dense_hessian.py
def sample(
    state: DenseLaplaceState,
    sample_shape: torch.Size = torch.Size([]),
) -> TensorTree:
    """Sample from Normal distribution over parameters.

    Args:
        state: State encoding mean and precision matrix.
        sample_shape: Shape of the desired samples.

    Returns:
        Sample(s) from the Normal distribution.
    """
    samples = torch.distributions.MultivariateNormal(
        loc=torch.zeros(state.prec.shape[0], device=state.prec.device),
        precision_matrix=state.prec,
        validate_args=False,
    ).sample(sample_shape)
    samples = samples.flatten(end_dim=-2)  # ensure samples is 2D
    mean_flat, unravel_func = tree_ravel(state.params)
    samples += mean_flat
    samples = torch.vmap(unravel_func)(samples)
    samples = tree_map(lambda x: x.reshape(sample_shape + x.shape[-1:]), samples)
    return samples
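Finally, drawing from the fitted Gaussian (continuing the sketches above; note that state.prec must be positive definite for MultivariateNormal). Each leaf of the returned tree gains a leading sample dimension:

```python
import torch
from posteriors.laplace.dense_hessian import sample

draws = sample(state, sample_shape=torch.Size([100]))
print(draws["w"].shape)  # torch.Size([100, 3]): 100 draws of the 3-vector "w"
```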