Utils

posteriors.utils.CatchAuxError

Bases: AbstractContextManager

Context manager to catch errors when auxiliary output is not found.

Source code in posteriors/utils.py
class CatchAuxError(contextlib.AbstractContextManager):
    """Context manager to catch errors when auxiliary output is not found."""

    def __exit__(self, exc_type, exc_value, traceback):
        if exc_type is not None:
            if NO_AUX_ERROR_MSG in str(exc_value):
                raise RuntimeError(
                    "Auxiliary output not found. Perhaps you have forgotten to return "
                    "the aux output?\n"
                    "\tIf you don't have any auxiliary info, simply amend to e.g. "
                    "log_posterior(params, batch) -> Tuple[float, torch.tensor([])].\n"
                    "\tMore info at https://normal-computing.github.io/posteriors/log_posteriors"
                )
            elif NON_TENSOR_AUX_ERROR_MSG in str(exc_value):
                raise RuntimeError(
                    "Auxiliary output should be a TensorTree. If you don't have any "
                    "auxiliary info, simply amend to e.g. "
                    "log_posterior(params, batch) -> Tuple[float, torch.tensor([])].\n"
                    "\tMore info at https://normal-computing.github.io/posteriors/log_posteriors"
                )
        return False
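
A minimal usage sketch (the log_posterior, params and batch below are hypothetical placeholders), mirroring how posteriors wraps torch.func transforms that expect auxiliary output:

import torch
from posteriors.utils import CatchAuxError

# A log density in the form posteriors expects: returns (value, aux)
def log_posterior(params, batch):
    log_prob = -(params["w"] * batch).pow(2).sum()
    return log_prob, torch.tensor([])

params = {"w": torch.ones(3)}
batch = torch.randn(3)

# Wrapping the torch.func transform in CatchAuxError converts the cryptic
# errors raised when aux is missing or not a TensorTree into the informative
# RuntimeErrors above
with CatchAuxError():
    grads, aux = torch.func.grad(log_posterior, has_aux=True)(params, batch)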

posteriors.utils.model_to_function(model)

Converts a model into a function that maps parameters and inputs to outputs.

Convenience wrapper around torch.func.functional_call.

Parameters:

    model (Module, required): torch.nn.Module with parameters stored in .named_parameters().

Returns:

    Callable[[TensorTree, Any], Any]: Function that takes a PyTree of parameters as well as any input arg or kwargs and returns the output of the model.

Source code in posteriors/utils.py
def model_to_function(model: torch.nn.Module) -> Callable[[TensorTree, Any], Any]:
    """Converts a model into a function that maps parameters and inputs to outputs.

    Convenience wrapper around [torch.func.functional_call](https://pytorch.org/docs/stable/generated/torch.func.functional_call.html).

    Args:
        model: torch.nn.Module with parameters stored in .named_parameters().

    Returns:
        Function that takes a PyTree of parameters as well as any input
            arg or kwargs and returns the output of the model.
    """

    def func_model(p_dict, *args, **kwargs):
        return functional_call(model, p_dict, args=args, kwargs=kwargs)

    return func_model
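
A minimal usage sketch, assuming a small hypothetical torch.nn.Linear model:

import torch
from posteriors.utils import model_to_function

model = torch.nn.Linear(3, 2)
model_func = model_to_function(model)

params = dict(model.named_parameters())
inputs = torch.randn(5, 3)

# Equivalent to model(inputs), but with parameters passed explicitly so the
# call composes with torch.func transforms
outputs = model_func(params, inputs)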

posteriors.utils.linearized_forward_diag(forward_func, params, batch, sd_diag)

Compute the linearized forward mean and its square root covariance, assuming posterior covariance over parameters is diagonal.

$$
f(x | θ) \sim N(x | f(x | θₘ), J(x | θₘ) \Sigma J(x | θₘ)^T)
$$

where \(θₘ\) is the MAP estimate, \(\Sigma\) is the diagonal covariance approximation at the MAP and \(J(x | θₘ)\) is the Jacobian of the forward function \(f(x | θₘ)\) with respect to \(θₘ\).

For more info on linearized models see Foong et al, 2019.

Parameters:

    forward_func (ForwardFn, required): A function that takes params and batch and returns the forward values and any auxiliary information. Forward values must be a dim=2 Tensor with batch dimension in its first axis.
    params (TensorTree, required): PyTree of tensors.
    batch (TensorTree, required): PyTree of tensors.
    sd_diag (TensorTree, required): PyTree of tensors of same shape as params.

Returns:

    Tuple[TensorTree, Tensor, TensorTree]: A tuple of (forward_vals, chol, aux) where forward_vals is the output of the forward function (mean), chol is the tensor square root of the covariance matrix (non-diagonal) and aux is auxiliary info from the forward function.

Source code in posteriors/utils.py
def linearized_forward_diag(
    forward_func: ForwardFn, params: TensorTree, batch: TensorTree, sd_diag: TensorTree
) -> Tuple[TensorTree, Tensor, TensorTree]:
    """Compute the linearized forward mean and its square root covariance, assuming
    posterior covariance over parameters is diagonal.

    $$
    f(x | θ) \\sim N(x | f(x | θₘ), J(x | θₘ) \\Sigma J(x | θₘ)^T)
    $$
    where $θₘ$ is the MAP estimate, $\\Sigma$ is the diagonal covariance approximation
    at the MAP and $J(x | θₘ)$ is the Jacobian of the forward function $f(x | θₘ)$ with
    respect to $θₘ$.

    For more info on linearized models see [Foong et al, 2019](https://arxiv.org/abs/1906.11537).

    Args:
        forward_func: A function that takes params and batch and returns the forward
            values and any auxiliary information. Forward values must be a dim=2 Tensor
            with batch dimension in its first axis.
        params: PyTree of tensors.
        batch: PyTree of tensors.
        sd_diag: PyTree of tensors of same shape as params.

    Returns:
        A tuple of (forward_vals, chol, aux) where forward_vals is the output of the
            forward function (mean), chol is the tensor square root of the covariance
            matrix (non-diagonal) and aux is auxiliary info from the forward function.
    """
    forward_vals, aux = forward_func(params, batch)

    with torch.no_grad(), CatchAuxError():
        jac, _ = jacrev(forward_func, has_aux=True)(params, batch)

    # Convert Jacobian to be flat in parameter dimension
    jac = tree_flatten(jac)[0]
    jac = torch.cat([x.flatten(start_dim=2) for x in jac], dim=2)

    # Flatten the diagonal square root covariance
    sd_diag = tree_flatten(sd_diag)[0]
    sd_diag = torch.cat([x.flatten() for x in sd_diag])

    # Cholesky of J @ Σ @ J^T
    linearised_chol = torch.linalg.cholesky((jac * sd_diag**2) @ jac.transpose(-1, -2))

    return forward_vals, linearised_chol, aux
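
A minimal sketch with a hypothetical linear forward function (note that the forward function must also return auxiliary information):

import torch
from optree import tree_map
from posteriors.utils import linearized_forward_diag

def forward(params, batch):
    logits = batch @ params["w"]     # (batch_size, 2) forward values
    return logits, torch.tensor([])  # must return (values, aux)

params = {"w": torch.randn(3, 2)}
sd_diag = tree_map(lambda p: 0.1 * torch.ones_like(p), params)
batch = torch.randn(4, 3)

mean, chol, aux = linearized_forward_diag(forward, params, batch, sd_diag)
# mean has shape (4, 2); chol has shape (4, 2, 2), a Cholesky factor per sample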

posteriors.utils.hvp(f, primals, tangents, has_aux=False)

Hessian vector product.

H(primals) @ tangents

where H(primals) is the Hessian of f evaluated at primals.

Taken from jacobians_hessians.html. Follows API from torch.func.jvp.

Parameters:

    f (Callable, required): A function with scalar output.
    primals (tuple, required): Tuple of e.g. tensor or dict with tensor values to evaluate f at.
    tangents (tuple, required): Tuple matching structure of primals.
    has_aux (bool, default: False): Whether f returns auxiliary information.

Returns:

    Tuple[float, TensorTree] | Tuple[float, TensorTree, Any]: Returns a (gradient, hvp_out) tuple containing the gradient of func evaluated at primals and the Hessian-vector product. If has_aux is True, then instead returns a (gradient, hvp_out, aux) tuple.

Source code in posteriors/utils.py
def hvp(
    f: Callable, primals: tuple, tangents: tuple, has_aux: bool = False
) -> Tuple[float, TensorTree] | Tuple[float, TensorTree, Any]:
    """Hessian vector product.

    H(primals) @ tangents

    where H(primals) is the Hessian of f evaluated at primals.

    Taken from [jacobians_hessians.html](https://pytorch.org/functorch/nightly/notebooks/jacobians_hessians.html).
    Follows API from [`torch.func.jvp`](https://pytorch.org/docs/stable/generated/torch.func.jvp.html).

    Args:
        f: A function with scalar output.
        primals: Tuple of e.g. tensor or dict with tensor values to evaluate f at.
        tangents: Tuple matching structure of primals.
        has_aux: Whether f returns auxiliary information.

    Returns:
        Returns a (gradient, hvp_out) tuple containing the gradient of func evaluated at
            primals and the Hessian-vector product. If has_aux is True, then instead
            returns a (gradient, hvp_out, aux) tuple.
    """
    return jvp(grad(f, has_aux=has_aux), primals, tangents, has_aux=has_aux)
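
A minimal sketch for a scalar function of a single tensor (the cubic function and shapes are arbitrary choices for illustration):

import torch
from posteriors.utils import hvp

def f(x):
    return (x**3).sum()

x = torch.arange(1.0, 4.0)  # primals
v = torch.ones(3)           # tangents

grad_f, hvp_out = hvp(f, (x,), (v,))
# grad_f = 3 * x**2 and hvp_out = 6 * x * v, i.e. the Hessian diag(6x) applied to v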

posteriors.utils.fvp(f, primals, tangents, has_aux=False, normalize=False)

Empirical Fisher vector product.

F(primals) @ tangents

where F(primals) is the empirical Fisher of f evaluated at primals.

The empirical Fisher is defined as:

$$
F(θ) = J_f(θ) J_f(θ)^T
$$

where typically \(f_θ\) is the per-sample log likelihood (with elements \(\log p(y_i | x_i, θ)\) for a model with primals \(θ\) given inputs \(x_i\) and labels \(y_i\)).

If normalize=True, then \(F(θ)\) is divided by the number of outputs from f (i.e. batchsize).

Follows API from torch.func.jvp.

More info on empirical Fisher matrices can be found in Martens, 2020.

Examples:

from functools import partial
from optree import tree_map
import torch
from posteriors import fvp

# Load model that outputs logits
# Load batch = {'inputs': ..., 'labels': ...}

def log_likelihood_per_sample(params, batch):
    output = torch.func.functional_call(model, params, batch["inputs"])
    return -torch.nn.functional.cross_entropy(
        output, batch["labels"], reduction="none"
    )

params = dict(model.named_parameters())
v = tree_map(lambda x: torch.randn_like(x), params)
fvp_result = fvp(
    partial(log_likelihood_per_sample, batch=batch),
    (params,),
    (v,)
)

Parameters:

    f (Callable, required): A function with tensor output. Typically this is the per-sample log likelihood of a model.
    primals (tuple, required): Tuple of e.g. tensor or dict with tensor values to evaluate f at.
    tangents (tuple, required): Tuple matching structure of primals.
    has_aux (bool, default: False): Whether f returns auxiliary information.
    normalize (bool, default: False): Whether to normalize, divide by the dimension of the output from f.

Returns:

    Tuple[float, TensorTree] | Tuple[float, TensorTree, Any]: Returns a (output, fvp_out) tuple containing the output of func evaluated at primals and the empirical Fisher-vector product. If has_aux is True, then instead returns a (output, fvp_out, aux) tuple.

Source code in posteriors/utils.py
def fvp(
    f: Callable,
    primals: tuple,
    tangents: tuple,
    has_aux: bool = False,
    normalize: bool = False,
) -> Tuple[float, TensorTree] | Tuple[float, TensorTree, Any]:
    """Empirical Fisher vector product.

    F(primals) @ tangents

    where F(primals) is the empirical Fisher of f evaluated at primals.

    The empirical Fisher is defined as:
    $$
    F(θ) = J_f(θ) J_f(θ)^T
    $$
    where typically $f_θ$ is the per-sample log likelihood (with elements
    $\\log p(y_i | x_i, θ)$ for a model with `primals` $θ$ given inputs $x_i$ and
    labels $y_i$).

    If `normalize=True`, then $F(θ)$ is divided by the number of outputs from f
    (i.e. batchsize).

    Follows API from [`torch.func.jvp`](https://pytorch.org/docs/stable/generated/torch.func.jvp.html).

    More info on empirical Fisher matrices can be found in
    [Martens, 2020](https://jmlr.org/papers/volume21/17-678/17-678.pdf).

    Examples:
        ```python
        from functools import partial
        from optree import tree_map
        import torch
        from posteriors import fvp

        # Load model that outputs logits
        # Load batch = {'inputs': ..., 'labels': ...}

        def log_likelihood_per_sample(params, batch):
            output = torch.func.functional_call(model, params, batch["inputs"])
            return -torch.nn.functional.cross_entropy(
                output, batch["labels"], reduction="none"
            )

        params = dict(model.named_parameters())
        v = tree_map(lambda x: torch.randn_like(x), params)
        fvp_result = fvp(
            partial(log_likelihood_per_sample, batch=batch),
            (params,),
            (v,)
        )
        ```

    Args:
        f: A function with tensor output.
            Typically this is the [per-sample log likelihood of a model](https://pytorch.org/tutorials/intermediate/per_sample_grads.html).
        primals: Tuple of e.g. tensor or dict with tensor values to evaluate f at.
        tangents: Tuple matching structure of primals.
        has_aux: Whether f returns auxiliary information.
        normalize: Whether to normalize, divide by the dimension of the output from f.

    Returns:
        Returns a (output, fvp_out) tuple containing the output of func evaluated at
            primals and the empirical Fisher-vector product. If has_aux is True, then
            instead returns a (output, fvp_out, aux) tuple.
    """
    jvp_output = jvp(f, primals, tangents, has_aux=has_aux)
    Jv = jvp_output[1]
    f_vjp = vjp(f, *primals, has_aux=has_aux)[1]
    Fv = f_vjp(Jv)[0]

    if normalize:
        output_dim = tree_flatten(jvp_output[0])[0][0].shape[0]
        Fv = tree_map(lambda x: x / output_dim, Fv)

    return jvp_output[0], Fv, *jvp_output[2:]

posteriors.utils.empirical_fisher(f, argnums=0, has_aux=False, normalize=False)

Constructs function to compute the empirical Fisher information matrix of a function f with respect to its parameters, defined as (unnormalized):

$$
F(θ) = J_f(θ) J_f(θ)^T
$$

where typically \(f_θ\) is the per-sample log likelihood (with elements \(\log p(y_i | x_i, θ)\) for a model with primals \(θ\) given inputs \(x_i\) and labels \(y_i\)).

If normalize=True, then \(F(θ)\) is divided by the number of outputs from f (i.e. batchsize).

The empirical Fisher will be provided as a square tensor with respect to the ravelled parameters. flat_params, params_unravel = optree.tree_ravel(params).

Follows API from torch.func.jacrev.

More info on empirical Fisher matrices can be found in Martens, 2020.

Examples:

import torch
from posteriors import empirical_fisher, per_samplify

# Load model that outputs logits
# Load batch = {'inputs': ..., 'labels': ...}

def log_likelihood(params, batch):
    output = torch.func.functional_call(model, params, batch['inputs'])
    return -torch.nn.functional.cross_entropy(output, batch['labels'])

log_likelihood_per_sample = per_samplify(log_likelihood)
params = dict(model.named_parameters())
ef_result = empirical_fisher(log_likelihood_per_sample)(params, batch)

Parameters:

    f (Callable, required): A Python function that takes one or more arguments, one of which must be a Tensor, and returns one or more Tensors. Typically this is the per-sample log likelihood of a model.
    argnums (int | Sequence[int], default: 0): Optional, integer or sequence of integers. Specifies which positional argument(s) to differentiate with respect to.
    has_aux (bool, default: False): Whether f returns auxiliary information.
    normalize (bool, default: False): Whether to normalize, divide by the dimension of the output from f.

Returns:

    Callable: A function with the same arguments as f that returns the empirical Fisher, F. If has_aux is True, then the function instead returns a tuple of (F, aux).

Source code in posteriors/utils.py
def empirical_fisher(
    f: Callable,
    argnums: int | Sequence[int] = 0,
    has_aux: bool = False,
    normalize: bool = False,
) -> Callable:
    """
    Constructs function to compute the empirical Fisher information matrix of a function
    f with respect to its parameters, defined as (unnormalized):
    $$
    F(θ) = J_f(θ) J_f(θ)^T
    $$
    where typically $f_θ$ is the per-sample log likelihood (with elements
    $\\log p(y_i | x_i, θ)$ for a model with `primals` $θ$ given inputs $x_i$ and
    labels $y_i$).

    If `normalize=True`, then $F(θ)$ is divided by the number of outputs from f
    (i.e. batchsize).

    The empirical Fisher will be provided as a square tensor with respect to the
    ravelled parameters.
    `flat_params, params_unravel = optree.tree_ravel(params)`.

    Follows API from [`torch.func.jacrev`](https://pytorch.org/functorch/stable/generated/functorch.jacrev.html).

    More info on empirical Fisher matrices can be found in
    [Martens, 2020](https://jmlr.org/papers/volume21/17-678/17-678.pdf).

    Examples:
        ```python
        import torch
        from posteriors import empirical_fisher, per_samplify

        # Load model that outputs logits
        # Load batch = {'inputs': ..., 'labels': ...}

        def log_likelihood(params, batch):
            output = torch.func.functional_call(model, params, batch['inputs'])
            return -torch.nn.functional.cross_entropy(output, batch['labels'])

        log_likelihood_per_sample = per_samplify(log_likelihood)
        params = dict(model.named_parameters())
        ef_result = empirical_fisher(log_likelihood_per_sample)(params, batch)
        ```

    Args:
        f:  A Python function that takes one or more arguments, one of which must be a
            Tensor, and returns one or more Tensors.
            Typically this is the [per-sample log likelihood of a model](https://pytorch.org/tutorials/intermediate/per_sample_grads.html).
        argnums: Optional, integer or sequence of integers. Specifies which
            positional argument(s) to differentiate with respect to.
        has_aux: Whether f returns auxiliary information.
        normalize: Whether to normalize, divide by the dimension of the output from f.

    Returns:
        A function with the same arguments as f that returns the empirical Fisher, F.
            If has_aux is True, then the function instead returns a tuple of (F, aux).
    """

    def f_to_flat(*args, **kwargs):
        f_out = f(*args, **kwargs)
        f_out_val = f_out[0] if has_aux else f_out
        f_out_val = tree_ravel(f_out_val)[0]
        return (f_out_val, f_out[1]) if has_aux else f_out_val

    def fisher(*args, **kwargs):
        jac_output = jacrev(f_to_flat, argnums=argnums, has_aux=has_aux)(
            *args, **kwargs
        )
        jac = jac_output[0] if has_aux else jac_output

        # Convert Jacobian to tensor, flat in parameter dimension
        jac = torch.vmap(lambda x: tree_ravel(x)[0])(jac)

        rescale = 1 / jac.shape[0] if normalize else 1

        if has_aux:
            return jac.T @ jac * rescale, jac_output[1]
        else:
            return jac.T @ jac * rescale

    return fisher

posteriors.utils.ggnvp(forward, loss, primals, tangents, forward_has_aux=False, loss_has_aux=False, normalize=False)

Generalised Gauss-Newton vector product.

Equivalent to the (non-empirical) Fisher vector product when loss is the negative log likelihood of an exponential family distribution as a function of its natural parameter.

Defined as

$$
G(θ) = J_f(θ) H_l(z) J_f(θ)^T
$$

where \(z = f(θ)\) is the output of the forward function \(f\) and \(l(z)\) is a loss function with scalar output.

Thus \(J_f(θ)\) is the Jacobian of the forward function \(f\) evaluated at primals \(θ\), with dimensions (dz, dθ). And \(H_l(z)\) is the Hessian of the loss function \(l\) evaluated at z = f(θ), with dimensions (dz, dz).

Follows API from torch.func.jvp.

More info on Fisher and GGN matrices can be found in Martens, 2020.

Examples:

from functools import partial
from optree import tree_map
import torch
from posteriors import ggnvp

# Load model that outputs logits
# Load batch = {'inputs': ..., 'labels': ...}

def forward(params, inputs):
    return torch.func.functional_call(model, params, inputs)

def loss(logits, labels):
    return torch.nn.functional.cross_entropy(logits, labels)

params = dict(model.named_parameters())
v = tree_map(lambda x: torch.randn_like(x), params)
ggnvp_result = ggnvp(
    partial(forward, inputs=batch['inputs']),
    partial(loss, labels=batch['labels']),
    (params,),
    (v,),
)

Parameters:

    forward (Callable, required): A function with tensor output.
    loss (Callable, required): A function that maps the output of forward to a scalar output.
    primals (tuple, required): Tuple of e.g. tensor or dict with tensor values to evaluate f at.
    tangents (tuple, required): Tuple matching structure of primals.
    forward_has_aux (bool, default: False): Whether forward returns auxiliary information.
    loss_has_aux (bool, default: False): Whether loss returns auxiliary information.
    normalize (bool, default: False): Whether to normalize, divide by the first dimension of the output from f.

Returns:

    Tuple[float, TensorTree] | Tuple[float, TensorTree, Any] | Tuple[float, TensorTree, Any, Any]: Returns a (output, ggnvp_out) tuple, where output is a tuple of (forward(primals), grad(loss)(forward(primals))). If forward_has_aux or loss_has_aux is True, then instead returns a (output, ggnvp_out, aux) or (output, ggnvp_out, forward_aux, loss_aux) tuple accordingly.

Source code in posteriors/utils.py
def ggnvp(
    forward: Callable,
    loss: Callable,
    primals: tuple,
    tangents: tuple,
    forward_has_aux: bool = False,
    loss_has_aux: bool = False,
    normalize: bool = False,
) -> (
    Tuple[float, TensorTree]
    | Tuple[float, TensorTree, Any]
    | Tuple[float, TensorTree, Any, Any]
):
    """Generalised Gauss-Newton vector product.

    Equivalent to the (non-empirical) Fisher vector product when `loss` is the negative
    log likelihood of an exponential family distribution as a function of its natural
    parameter.

    Defined as
    $$
    G(θ) = J_f(θ) H_l(z) J_f(θ)^T
    $$
    where $z = f(θ)$ is the output of the forward function $f$ and $l(z)$
    is a loss function with scalar output.

    Thus $J_f(θ)$ is the Jacobian of the forward function $f$ evaluated
    at `primals` $θ$, with dimensions `(dz, dθ)`.
    And $H_l(z)$ is the Hessian of the loss function $l$ evaluated at `z = f(θ)`, with
    dimensions `(dz, dz)`.

    Follows API from [`torch.func.jvp`](https://pytorch.org/docs/stable/generated/torch.func.jvp.html).

    More info on Fisher and GGN matrices can be found in
    [Martens, 2020](https://jmlr.org/papers/volume21/17-678/17-678.pdf).

    Examples:
        ```python
        from functools import partial
        from optree import tree_map
        import torch
        from posteriors import ggnvp

        # Load model that outputs logits
        # Load batch = {'inputs': ..., 'labels': ...}

        def forward(params, inputs):
            return torch.func.functional_call(model, params, inputs)

        def loss(logits, labels):
            return torch.nn.functional.cross_entropy(logits, labels)

        params = dict(model.named_parameters())
        v = tree_map(lambda x: torch.randn_like(x), params)
        ggnvp_result = ggnvp(
            partial(forward, inputs=batch['inputs']),
            partial(loss, labels=batch['labels']),
            (params,),
            (v,),
        )
        ```

    Args:
        forward: A function with tensor output.
        loss: A function that maps the output of forward to a scalar output.
        primals: Tuple of e.g. tensor or dict with tensor values to evaluate f at.
        tangents: Tuple matching structure of primals.
        forward_has_aux: Whether forward returns auxiliary information.
        loss_has_aux: Whether loss returns auxiliary information.
        normalize: Whether to normalize, divide by the first dimension of the output
            from f.

    Returns:
        Returns a (output, ggnvp_out) tuple, where output is a tuple of
            `(forward(primals), grad(loss)(forward(primals)))`.
            If forward_has_aux or loss_has_aux is True, then instead returns a
            (output, ggnvp_out, aux) or
            (output, ggnvp_out, forward_aux, loss_aux) tuple accordingly.
    """

    jvp_output = jvp(forward, primals, tangents, has_aux=forward_has_aux)
    z = jvp_output[0]
    Jv = jvp_output[1]
    HJv_output = hvp(loss, (z,), (Jv,), has_aux=loss_has_aux)
    HJv = HJv_output[1]

    if normalize:
        output_dim = tree_flatten(jvp_output[0])[0][0].shape[0]
        HJv = tree_map(lambda x: x / output_dim, HJv)

    forward_vjp = vjp(forward, *primals, has_aux=forward_has_aux)[1]
    JTHJv = forward_vjp(HJv)[0]

    return (jvp_output[0], HJv_output[0]), JTHJv, *jvp_output[2:], *HJv_output[2:]

posteriors.utils.ggn(forward, loss, argnums=0, forward_has_aux=False, loss_has_aux=False, normalize=False)

Constructs function to compute the Generalised Gauss-Newton matrix.

Equivalent to the (non-empirical) Fisher when loss is the negative log likelihood of an exponential family distribution as a function of its natural parameter.

Defined as

$$
G(θ) = J_f(θ) H_l(z) J_f(θ)^T
$$

where \(z = f(θ)\) is the output of the forward function \(f\) and \(l(z)\) is a loss function with scalar output.

Thus \(J_f(θ)\) is the Jacobian of the forward function \(f\) evaluated at primals \(θ\). And \(H_l(z)\) is the Hessian of the loss function \(l\) evaluated at z = f(θ).

Requires output from forward to be a tensor and therefore loss takes a tensor as input. Although both support aux output.

If normalize=True, then \(G(θ)\) is divided by the size of the leading dimension of outputs from forward (i.e. batchsize).

The GGN will be provided as a square tensor with respect to the ravelled parameters. flat_params, params_unravel = optree.tree_ravel(params).

Follows API from torch.func.jacrev.

More info on Fisher and GGN matrices can be found in Martens, 2020.

Examples:

from functools import partial
import torch
from posteriors import ggn

# Load model that outputs logits
# Load batch = {'inputs': ..., 'labels': ...}

def forward(params, inputs):
    return torch.func.functional_call(model, params, inputs)

def loss(logits, labels):
    return torch.nn.functional.cross_entropy(logits, labels)

params = dict(model.named_parameters())
ggn_result = ggn(
    partial(forward, inputs=batch['inputs']),
    partial(loss, labels=batch['labels']),
)(params)

Parameters:

    forward (Callable, required): A function with tensor output.
    loss (Callable, required): A function that maps the output of forward to a scalar output. Takes a single input and returns a scalar (and possibly aux).
    argnums (int | Sequence[int], default: 0): Optional, integer or sequence of integers. Specifies which positional argument(s) to differentiate forward with respect to.
    forward_has_aux (bool, default: False): Whether forward returns auxiliary information.
    loss_has_aux (bool, default: False): Whether loss returns auxiliary information.
    normalize (bool, default: False): Whether to normalize, divide by the first dimension of the output from f.

Returns:

    Callable: A function with the same arguments as f that returns the tensor GGN. If has_aux is True, then the function instead returns a tuple of (F, aux).

Source code in posteriors/utils.py
def ggn(
    forward: Callable,
    loss: Callable,
    argnums: int | Sequence[int] = 0,
    forward_has_aux: bool = False,
    loss_has_aux: bool = False,
    normalize: bool = False,
) -> Callable:
    """
    Constructs function to compute the Generalised Gauss-Newton matrix.

    Equivalent to the (non-empirical) Fisher when `loss` is the negative
    log likelihood of an exponential family distribution as a function of its natural
    parameter.

    Defined as
    $$
    G(θ) = J_f(θ) H_l(z) J_f(θ)^T
    $$
    where $z = f(θ)$ is the output of the forward function $f$ and $l(z)$
    is a loss function with scalar output.

    Thus $J_f(θ)$ is the Jacobian of the forward function $f$ evaluated
    at `primals` $θ$. And $H_l(z)$ is the Hessian of the loss function $l$ evaluated
    at `z = f(θ)`.

    Requires output from `forward` to be a tensor and therefore `loss` takes a tensor as
    input. Although both support `aux` output.

    If `normalize=True`, then $G(θ)$ is divided by the size of the leading dimension of
    outputs from `forward` (i.e. batchsize).

    The GGN will be provided as a square tensor with respect to the
    ravelled parameters.
    `flat_params, params_unravel = optree.tree_ravel(params)`.

    Follows API from [`torch.func.jacrev`](https://pytorch.org/functorch/stable/generated/functorch.jacrev.html).

    More info on Fisher and GGN matrices can be found in
    [Martens, 2020](https://jmlr.org/papers/volume21/17-678/17-678.pdf).

    Examples:
        ```python
        from functools import partial
        import torch
        from posteriors import ggn

        # Load model that outputs logits
        # Load batch = {'inputs': ..., 'labels': ...}

        def forward(params, inputs):
            return torch.func.functional_call(model, params, inputs)

        def loss(logits, labels):
            return torch.nn.functional.cross_entropy(logits, labels)

        params = dict(model.named_parameters())
        ggn_result = ggn(
            partial(forward, inputs=batch['inputs']),
            partial(loss, labels=batch['labels']),
        )(params)
        ```

    Args:
        forward: A function with tensor output.
        loss: A function that maps the output of forward to a scalar output.
            Takes a single input and returns a scalar (and possibly aux).
        argnums: Optional, integer or sequence of integers. Specifies which
            positional argument(s) to differentiate `forward` with respect to.
        forward_has_aux: Whether forward returns auxiliary information.
        loss_has_aux: Whether loss returns auxiliary information.
        normalize: Whether to normalize, divide by the first dimension of the output
            from f.

    Returns:
        A function with the same arguments as f that returns the tensor GGN.
            If has_aux is True, then the function instead returns a tuple of (F, aux).
    """
    assert argnums == 0, "Only argnums=0 is supported for now."

    def internal_ggn(params):
        flat_params, params_unravel = tree_ravel(params)

        def flat_params_to_forward(fps):
            return forward(params_unravel(fps))

        jac, hess, aux = _hess_and_jac_for_ggn(
            flat_params_to_forward,
            loss,
            argnums,
            forward_has_aux,
            loss_has_aux,
            normalize,
            flat_params,
        )

        if aux:
            return jac.T @ (hess @ jac), *aux
        else:
            return jac.T @ (hess @ jac)

    return internal_ggn

posteriors.utils.diag_ggn(forward, loss, argnums=0, forward_has_aux=False, loss_has_aux=False, normalize=False)

Constructs function to compute the diagonal of the Generalised Gauss-Newton matrix.

Equivalent to the (non-empirical) diagonal Fisher when loss is the negative log likelihood of an exponential family distribution as a function of its natural parameter.

The GGN is defined as

$$
G(θ) = J_f(θ) H_l(z) J_f(θ)^T
$$

where \(z = f(θ)\) is the output of the forward function \(f\) and \(l(z)\) is a loss function with scalar output.

Thus \(J_f(θ)\) is the Jacobian of the forward function \(f\) evaluated at primals \(θ\). And \(H_l(z)\) is the Hessian of the loss function \(l\) evaluated at z = f(θ).

Requires output from forward to be a tensor and therefore loss takes a tensor as input. Although both support aux output.

If normalize=True, then \(G(θ)\) is divided by the size of the leading dimension of outputs from forward (i.e. batchsize).

Unlike posteriors.ggn, the output will be in PyTree form matching the input.

Follows API from torch.func.jacrev.

More info on Fisher and GGN matrices can be found in Martens, 2020.

Examples:

from functools import partial
import torch
from posteriors import diag_ggn

# Load model that outputs logits
# Load batch = {'inputs': ..., 'labels': ...}

def forward(params, inputs):
    return torch.func.functional_call(model, params, inputs)

def loss(logits, labels):
    return torch.nn.functional.cross_entropy(logits, labels)

params = dict(model.named_parameters())
ggndiag_result = diag_ggn(
    partial(forward, inputs=batch['inputs']),
    partial(loss, labels=batch['labels']),
)(params)

Parameters:

    forward (Callable, required): A function with tensor output.
    loss (Callable, required): A function that maps the output of forward to a scalar output. Takes a single input and returns a scalar (and possibly aux).
    argnums (int | Sequence[int], default: 0): Optional, integer or sequence of integers. Specifies which positional argument(s) to differentiate forward with respect to.
    forward_has_aux (bool, default: False): Whether forward returns auxiliary information.
    loss_has_aux (bool, default: False): Whether loss returns auxiliary information.
    normalize (bool, default: False): Whether to normalize, divide by the first dimension of the output from f.

Returns:

    Callable: A function with the same arguments as f that returns the diagonal GGN. If has_aux is True, then the function instead returns a tuple of (F, aux).

Source code in posteriors/utils.py
def diag_ggn(
    forward: Callable,
    loss: Callable,
    argnums: int | Sequence[int] = 0,
    forward_has_aux: bool = False,
    loss_has_aux: bool = False,
    normalize: bool = False,
) -> Callable:
    """
    Constructs function to compute the diagonal of the Generalised Gauss-Newton matrix.

    Equivalent to the (non-empirical) diagonal Fisher when `loss` is the negative
    log likelihood of an exponential family distribution as a function of its natural
    parameter.

    The GGN is defined as
    $$
    G(θ) = J_f(θ) H_l(z) J_f(θ)^T
    $$
    where $z = f(θ)$ is the output of the forward function $f$ and $l(z)$
    is a loss function with scalar output.

    Thus $J_f(θ)$ is the Jacobian of the forward function $f$ evaluated
    at `primals` $θ$. And $H_l(z)$ is the Hessian of the loss function $l$ evaluated
    at `z = f(θ)`.

    Requires output from `forward` to be a tensor and therefore `loss` takes a tensor as
    input. Although both support `aux` output.

    If `normalize=True`, then $G(θ)$ is divided by the size of the leading dimension of
    outputs from `forward` (i.e. batchsize).

    Unlike `posteriors.ggn`, the output will be in PyTree form matching the input.

    Follows API from [`torch.func.jacrev`](https://pytorch.org/functorch/stable/generated/functorch.jacrev.html).

    More info on Fisher and GGN matrices can be found in
    [Martens, 2020](https://jmlr.org/papers/volume21/17-678/17-678.pdf).

    Examples:
        ```python
        from functools import partial
        import torch
        from posteriors import diag_ggn

        # Load model that outputs logits
        # Load batch = {'inputs': ..., 'labels': ...}

        def forward(params, inputs):
            return torch.func.functional_call(model, params, inputs)

        def loss(logits, labels):
            return torch.nn.functional.cross_entropy(logits, labels)

        params = dict(model.named_parameters())
        ggndiag_result = diag_ggn(
            partial(forward, inputs=batch['inputs']),
            partial(loss, labels=batch['labels']),
        )(params)
        ```

    Args:
        forward: A function with tensor output.
        loss: A function that maps the output of forward to a scalar output.
            Takes a single input and returns a scalar (and possibly aux).
        argnums: Optional, integer or sequence of integers. Specifies which
            positional argument(s) to differentiate `forward` with respect to.
        forward_has_aux: Whether forward returns auxiliary information.
        loss_has_aux: Whether loss returns auxiliary information.
        normalize: Whether to normalize, divide by the first dimension of the output
            from f.

    Returns:
        A function with the same arguments as f that returns the diagonal GGN.
            If has_aux is True, then the function instead returns a tuple of (F, aux).
    """
    assert argnums == 0, "Only argnums=0 is supported for now."

    def internal_ggn(params):
        flat_params, params_unravel = tree_ravel(params)

        def flat_params_to_forward(fps):
            return forward(params_unravel(fps))

        jac, hess, aux = _hess_and_jac_for_ggn(
            flat_params_to_forward,
            loss,
            argnums,
            forward_has_aux,
            loss_has_aux,
            normalize,
            flat_params,
        )

        G_diag = torch.einsum("ji,jk,ki->i", jac, hess, jac)
        G_diag = params_unravel(G_diag)

        if aux:
            return G_diag, *aux
        else:
            return G_diag

    return internal_ggn

posteriors.utils.cg(A, b, x0=None, *, maxiter=None, damping=0.0, tol=1e-05, atol=0.0, M=_identity)

Use Conjugate Gradient iteration to solve Ax = b. A is supplied as a function instead of a matrix.

Adapted from jax.scipy.sparse.linalg.cg.

Parameters:

    A (Callable, required): Callable that calculates the linear map (matrix-vector product) Ax when called like A(x). A must represent a hermitian, positive definite matrix, and must return array(s) with the same structure and shape as its argument.
    b (TensorTree, required): Right hand side of the linear system representing a single vector.
    x0 (TensorTree, default: None): Starting guess for the solution. Must have the same structure as b.
    maxiter (int, default: None): Maximum number of iterations. Iteration will stop after maxiter steps even if the specified tolerance has not been achieved.
    damping (float, default: 0.0): Damping term for the mvp function. Acts as regularization.
    tol (float, default: 1e-05): Tolerance for convergence.
    atol (float, default: 0.0): Tolerance for convergence. norm(residual) <= max(tol*norm(b), atol). The behaviour will differ from SciPy unless you explicitly pass atol to SciPy's cg.
    M (Callable, default: _identity): Preconditioner for A. See the preconditioned CG method.

Returns:

    x (TensorTree): The converged solution. Has the same structure as b.
    info (Any): Placeholder for convergence information.

Source code in posteriors/utils.py
def cg(
    A: Callable,
    b: TensorTree,
    x0: TensorTree = None,
    *,
    maxiter: int = None,
    damping: float = 0.0,
    tol: float = 1e-5,
    atol: float = 0.0,
    M: Callable = _identity,
) -> Tuple[TensorTree, Any]:
    """Use Conjugate Gradient iteration to solve ``Ax = b``.
    ``A`` is supplied as a function instead of a matrix.

    Adapted from [`jax.scipy.sparse.linalg.cg`](https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.sparse.linalg.cg.html).

    Args:
        A:  Callable that calculates the linear map (matrix-vector
            product) ``Ax`` when called like ``A(x)``. ``A`` must represent
            a hermitian, positive definite matrix, and must return array(s) with the
            same structure and shape as its argument.
        b:  Right hand side of the linear system representing a single vector.
        x0: Starting guess for the solution. Must have the same structure as ``b``.
        maxiter: Maximum number of iterations.  Iteration will stop after maxiter
            steps even if the specified tolerance has not been achieved.
        damping: damping term for the mvp function. Acts as regularization.
        tol: Tolerance for convergence.
        atol: Tolerance for convergence. ``norm(residual) <= max(tol*norm(b), atol)``.
            The behaviour will differ from SciPy unless you explicitly pass
            ``atol`` to SciPy's ``cg``.
        M: Preconditioner for A.
            See [the preconditioned CG method.](https://en.wikipedia.org/wiki/Conjugate_gradient_method#The_preconditioned_conjugate_gradient_method)

    Returns:
        x : The converged solution. Has the same structure as ``b``.
        info : Placeholder for convergence information.
    """
    if x0 is None:
        x0 = tree_map(torch.zeros_like, b)

    if maxiter is None:
        maxiter = 10 * tree_size(b)  # copied from scipy

    tol *= torch.tensor([1.0])
    atol *= torch.tensor([1.0])

    # tolerance handling uses the "non-legacy" behavior of scipy.sparse.linalg.cg
    bs = _vdot_real_tree(b, b)
    atol2 = torch.maximum(torch.square(tol) * bs, torch.square(atol))

    def A_damped(p):
        return _add(A(p), _mul(damping, p))

    def cond_fun(value):
        _, r, gamma, _, k = value
        rs = gamma.real if M is _identity else _vdot_real_tree(r, r)
        return (rs > atol2) & (k < maxiter)

    def body_fun(value):
        x, r, gamma, p, k = value
        Ap = A_damped(p)
        alpha = gamma / _vdot_real_tree(p, Ap)
        x_ = _add(x, _mul(alpha, p))
        r_ = _sub(r, _mul(alpha, Ap))
        z_ = M(r_)
        gamma_ = _vdot_real_tree(r_, z_)
        beta_ = gamma_ / gamma
        p_ = _add(z_, _mul(beta_, p))
        return x_, r_, gamma_, p_, k + 1

    r0 = _sub(b, A_damped(x0))
    p0 = z0 = r0
    gamma0 = _vdot_real_tree(r0, z0)
    initial_value = (x0, r0, gamma0, p0, 0)

    value = initial_value

    while cond_fun(value):
        value = body_fun(value)

    x_final, r, gamma, _, k = value
    # compute the final error and whether it has converged.
    rs = gamma if M is _identity else _vdot_real_tree(r, r)
    converged = rs <= atol2

    # additional info output structure
    info = {"error": rs, "converged": converged, "niter": k}

    return x_final, info
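
A minimal sketch solving a small dense system, with the matrix supplied only through its matrix-vector product (the 2x2 matrix is an arbitrary positive definite example):

import torch
from posteriors.utils import cg

A_mat = torch.tensor([[4.0, 1.0],
                      [1.0, 3.0]])  # symmetric positive definite
b = torch.tensor([1.0, 2.0])

def A(x):
    return A_mat @ x

x, info = cg(A, b)
# x approximately equals torch.linalg.solve(A_mat, b)
# info contains the final squared residual, convergence flag and iteration count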

posteriors.utils.diag_normal_log_prob(x, mean=0.0, sd_diag=1.0, normalize=True)

Evaluate multivariate normal log probability for a diagonal covariance matrix.

If either mean or sd_diag are scalars, they will be broadcast to the same shape as x (in a memory efficient manner).

Parameters:

    x (TensorTree, required): Value to evaluate log probability at.
    mean (float | TensorTree, default: 0.0): Mean of the distribution.
    sd_diag (float | TensorTree, default: 1.0): Square-root diagonal of the covariance matrix.
    normalize (bool, default: True): Whether to compute normalized log probability. If False the elementwise log prob is -0.5 * ((x - mean) / sd_diag)**2.

Returns:

    float: Scalar log probability.

Source code in posteriors/utils.py
def diag_normal_log_prob(
    x: TensorTree,
    mean: float | TensorTree = 0.0,
    sd_diag: float | TensorTree = 1.0,
    normalize: bool = True,
) -> float:
    """Evaluate multivariate normal log probability for a diagonal covariance matrix.

    If either mean or sd_diag are scalars, they will be broadcast to the same shape as x
    (in a memory efficient manner).

    Args:
        x: Value to evaluate log probability at.
        mean: Mean of the distribution.
        sd_diag: Square-root diagonal of the covariance matrix.
        normalize: Whether to compute normalized log probability.
            If False the elementwise log prob is -0.5 * ((x - mean) / sd_diag)**2.

    Returns:
        Scalar log probability.
    """
    if tree_size(mean) == 1:
        mean = tree_map(lambda t: torch.tensor(mean, device=t.device), x)
    if tree_size(sd_diag) == 1:
        sd_diag = tree_map(lambda t: torch.tensor(sd_diag, device=t.device), x)

    if normalize:

        def univariate_norm_and_sum(v, m, sd):
            return Normal(m, sd, validate_args=False).log_prob(v).sum()
    else:

        def univariate_norm_and_sum(v, m, sd):
            return (-0.5 * ((v - m) / sd) ** 2).sum()

    log_probs = tree_map(
        univariate_norm_and_sum,
        x,
        mean,
        sd_diag,
    )
    log_prob = tree_reduce(torch.add, log_probs)
    return log_prob
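
A minimal sketch evaluating the log density of a PyTree under a standard normal (the scalar mean and sd_diag are broadcast across the tree):

import torch
from posteriors.utils import diag_normal_log_prob

x = {"a": torch.zeros(3), "b": torch.ones(2)}

log_prob = diag_normal_log_prob(x, mean=0.0, sd_diag=1.0)
# Scalar tensor: the sum of univariate N(0, 1) log densities over all 5 elements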

posteriors.utils.diag_normal_sample(mean, sd_diag, sample_shape=torch.Size([]))

Sample from multivariate normal with diagonal covariance matrix.

If sd_diag is scalar, it will be broadcast to the same shape as mean (in a memory efficient manner).

Parameters:

    mean (TensorTree, required): Mean of the distribution.
    sd_diag (float | TensorTree, required): Square-root diagonal of the covariance matrix.
    sample_shape (torch.Size, default: torch.Size([])): Shape of the sample.

Returns:

    dict: Sample(s) from normal distribution with the same structure as mean and sd_diag.

Source code in posteriors/utils.py
def diag_normal_sample(
    mean: TensorTree,
    sd_diag: float | TensorTree,
    sample_shape: torch.Size = torch.Size([]),
) -> dict:
    """Sample from multivariate normal with diagonal covariance matrix.

    If sd_diag is scalar, it will be broadcast to the same shape as mean
    (in a memory efficient manner).

    Args:
        mean: Mean of the distribution.
        sd_diag: Square-root diagonal of the covariance matrix.
        sample_shape: Shape of the sample.

    Returns:
        Sample(s) from normal distribution with the same structure as mean and sd_diag.
    """
    if tree_size(sd_diag) == 1:
        sd_diag = tree_map(lambda t: torch.tensor(sd_diag, device=t.device), mean)

    return tree_map(
        lambda m, sd: m + torch.randn(sample_shape + m.shape, device=m.device) * sd,
        mean,
        sd_diag,
    )
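
A minimal sketch drawing a batch of samples around a PyTree mean with a scalar standard deviation:

import torch
from posteriors.utils import diag_normal_sample

mean = {"a": torch.zeros(3), "b": torch.ones(2)}

samples = diag_normal_sample(mean, sd_diag=0.1, sample_shape=torch.Size([10]))
# samples["a"].shape == (10, 3) and samples["b"].shape == (10, 2)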

posteriors.utils.per_samplify(f)

Converts a function that takes params and batch into one that provides an output for each batch sample.

output = f(params, batch)
per_sample_output = per_samplify(f)(params, batch)

For more info see per_sample_grads.html

Parameters:

    f (Callable[[TensorTree, TensorTree], Any], required): A function that takes params and batch and provides an output with size independent of batchsize (i.e. averaged).

Returns:

    Callable[[TensorTree, TensorTree], Any]: A new function that provides an output for each batch sample. per_sample_output = per_samplify(f)(params, batch)

Source code in posteriors/utils.py
def per_samplify(
    f: Callable[[TensorTree, TensorTree], Any],
) -> Callable[[TensorTree, TensorTree], Any]:
    """Converts a function that takes params and batch into one that provides an output
    for each batch sample.

    ```
    output = f(params, batch)
    per_sample_output = per_samplify(f)(params, batch)
    ```

    For more info see [per_sample_grads.html](https://pytorch.org/tutorials/intermediate/per_sample_grads.html)

    Args:
        f: A function that takes params and batch and provides an output with size
            independent of batchsize (i.e. averaged).

    Returns:
        A new function that provides an output for each batch sample.
            `per_sample_output  = per_samplify(f)(params, batch)`
    """

    @partial(torch.vmap, in_dims=(None, 0))
    def f_per_sample(params, batch):
        batch = tree_map(lambda x: x.unsqueeze(0), batch)
        return f(params, batch)

    @wraps(f)
    def f_per_sample_ensure_no_kwargs(params, batch):
        return f_per_sample(params, batch)  # vmap in_dims requires no kwargs

    return f_per_sample_ensure_no_kwargs
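
A minimal sketch with a hypothetical linear model and a batch of four samples:

import torch
from posteriors.utils import per_samplify

model = torch.nn.Linear(3, 2)

def log_likelihood(params, batch):
    output = torch.func.functional_call(model, params, batch["inputs"])
    return -torch.nn.functional.cross_entropy(output, batch["labels"])

params = dict(model.named_parameters())
batch = {"inputs": torch.randn(4, 3), "labels": torch.randint(2, (4,))}

# One scalar log likelihood per batch sample, shape (4,)
per_sample_lls = per_samplify(log_likelihood)(params, batch)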

posteriors.utils.is_scalar(x)

Returns True if x is a scalar (int, float, bool) or a tensor with a single element.

Parameters:

    x (Any, required): Any object.

Returns:

    bool: True if x is a scalar.

Source code in posteriors/utils.py
def is_scalar(x: Any) -> bool:
    """Returns True if x is a scalar (int, float, bool) or a tensor with a single element.

    Args:
        x: Any object.

    Returns:
        True if x is a scalar.
    """
    return isinstance(x, (int, float)) or (torch.is_tensor(x) and x.numel() == 1)
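
For example:

import torch
from posteriors.utils import is_scalar

is_scalar(3.0)                  # True
is_scalar(torch.tensor([2.0]))  # True, single-element tensor
is_scalar(torch.ones(3))        # False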