Laplace Dense GGN
posteriors.laplace.dense_ggn.build(forward, outer_log_likelihood, init_prec=0.0)
Builds a transform for a Generalized Gauss-Newton (GGN) Laplace approximation.
Equivalent to the (non-empirical) Fisher information matrix when the `outer_log_likelihood` is exponential family with natural parameter equal to the output from `forward`.
`forward` should output auxiliary information (or `torch.tensor([])`), `outer_log_likelihood` should not.
The GGN is defined as $$ G(θ) = J_f(θ) H_l(z) J_f(θ)^T $$ where \(z = f(θ)\) is the output of the forward function \(f\) and \(l(z)\) is a loss (negative log-likelihood) that maps the output of \(f\) to a scalar output.
More info on Fisher and GGN matrices can be found in Martens, 2020 and their use within a Laplace approximation in Daxberger et al, 2021.
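As a concrete illustration of the definition above, the following minimal sketch computes the GGN for a single input of a linear softmax classifier using plain PyTorch autograd. It is not the library's implementation, and the names `f`, `l`, `x`, `y` and `theta` are made up for the example. `torch.autograd.functional.jacobian` returns \(J_f\) with shape (outputs, parameters), so the product is written `J.T @ H @ J` here to make the shapes line up. The final lines compare \(H_l(z)\) with the Fisher information of a categorical distribution in its logits, \(\mathrm{diag}(p) - pp^T\), which is the equivalence mentioned above.

```python
import torch
from torch.autograd.functional import hessian, jacobian

torch.manual_seed(0)

# Hypothetical toy problem: one input, linear model, softmax likelihood
n_features, n_classes = 3, 2
x = torch.randn(n_features)
y = 1  # observed class label
theta = torch.randn(n_classes * n_features)  # flat parameter vector


def f(theta):
    # Forward function: parameters -> logits z
    return theta.reshape(n_classes, n_features) @ x


def l(z):
    # Loss: negative log-likelihood of label y given logits z
    return -torch.log_softmax(z, dim=-1)[y]


z = f(theta)
J = jacobian(f, theta)  # shape (n_classes, n_params)
H = hessian(l, z)       # shape (n_classes, n_classes)
G = J.T @ H @ J         # shape (n_params, n_params)

# Fisher information of a categorical distribution w.r.t. its logits
p = torch.softmax(z, dim=-1)
fisher_wrt_logits = torch.diag(p) - torch.outer(p, p)
```

Because \(H_l(z)\) is positive semi-definite for this likelihood, `G` is positive semi-definite as well, which is what makes it usable as a precision matrix.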
Parameters:
Name | Type | Description | Default |
---|---|---|---|
`forward` | `ForwardFn` | Function that takes parameters and input batch and returns a forward value (e.g. logits), not reduced over the batch, as well as auxiliary information. | required |
`outer_log_likelihood` | `OuterLogProbFn` | A function that takes the output of `forward` and returns the log-likelihood of that output, with no auxiliary information. | required |
`init_prec` | `TensorTree | float` | Initial precision matrix. If it is a float, it is defined as an identity matrix scaled by that float. | `0.0` |
Returns:
Type | Description |
---|---|
`Transform` | GGN Laplace approximation transform instance. |
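A minimal sketch of a `forward` / `outer_log_likelihood` pair that matches the descriptions above, for a small classifier. Several details are assumptions rather than documented behaviour: the toy `model`, the `(inputs, labels)` layout of `batch`, passing `batch` as a second argument to `outer_log_likelihood`, and summing the log-likelihood over the batch. The call to `build` is left as a comment, using the signature shown at the top of this page.

```python
import torch
from torch.func import functional_call

torch.manual_seed(0)

# Hypothetical toy model and data, for illustration only
model = torch.nn.Linear(3, 2)
params = dict(model.named_parameters())
batch = (torch.randn(4, 3), torch.tensor([0, 1, 1, 0]))


def forward(params, batch):
    # Returns per-example logits (not reduced over the batch)
    # plus auxiliary information; an empty tensor when there is none
    x, _ = batch
    logits = functional_call(model, params, (x,))
    return logits, torch.tensor([])


def outer_log_likelihood(logits, batch):
    # Assumed signature: takes the forward output and the batch,
    # returns a scalar log-likelihood and no auxiliary information
    _, y = batch
    return torch.distributions.Categorical(logits=logits).log_prob(y).sum()


logits, aux = forward(params, batch)
log_lik = outer_log_likelihood(logits, batch)

# With the library installed, the transform would then be built as:
# transform = posteriors.laplace.dense_ggn.build(forward, outer_log_likelihood)
```

Here the logits are the natural parameters of the categorical likelihood, which is the setting in which the GGN coincides with the Fisher information matrix.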
Source code in posteriors/laplace/dense_ggn.py
posteriors.laplace.dense_ggn.DenseLaplaceState
Bases: NamedTuple
State encoding a Normal distribution over parameters, with a dense precision matrix.
Attributes:
Name | Type | Description |
---|---|---|
`params` | `TensorTree` | Mean of the Normal distribution. |
`prec` | `Tensor` | Precision matrix of the Normal distribution. |
`aux` | `Any` | Auxiliary information from the `forward` call. |
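To make the shape of this state concrete, here is a hypothetical stand-in written as a plain `NamedTuple`. The class name `SketchDenseState`, the `None` default for `aux`, and the assumption that `prec` is a `(d, d)` tensor with `d` the total number of scalar parameters are illustrative guesses, not taken from the library source.

```python
from typing import Any, NamedTuple

import torch


class SketchDenseState(NamedTuple):
    # Hypothetical mirror of the documented attributes
    params: dict         # mean of the Normal distribution (a tree of tensors)
    prec: torch.Tensor   # dense precision matrix
    aux: Any = None      # auxiliary information, if any


params = {"weight": torch.zeros(2, 3), "bias": torch.zeros(2)}
d = sum(p.numel() for p in params.values())  # total number of scalar parameters

state = SketchDenseState(params=params, prec=torch.eye(d))
```

A dense precision matrix stores `d * d` entries, so this representation is only practical for models with a modest number of parameters.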
Source code in posteriors/laplace/dense_ggn.py
posteriors.laplace.dense_ggn.init(params, init_prec=0.0)
Initialise Normal distribution over parameters with a dense precision matrix.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
`params` | `TensorTree` | Mean of the Normal distribution. | required |
`init_prec` | `Tensor | float` | Initial precision matrix. If it is a float, it is defined as an identity matrix scaled by that float. | `0.0` |
Returns:
Type | Description |
---|---|
`DenseLaplaceState` | Initial DenseLaplaceState. |
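The handling of `init_prec` described above can be sketched in a few lines of plain PyTorch. The helper name `init_prec_sketch` is hypothetical and the code only illustrates the documented behaviour (a float becomes a scaled identity matrix); it is not the library's implementation.

```python
import torch


def init_prec_sketch(params, init_prec=0.0):
    # Total number of scalar parameters determines the matrix size
    d = sum(p.numel() for p in params.values())
    if isinstance(init_prec, float):
        # A float is interpreted as that float times the identity matrix
        return init_prec * torch.eye(d)
    # Otherwise assume a full precision matrix was supplied
    return init_prec


params = {"weight": torch.zeros(2, 3), "bias": torch.zeros(2)}

prec_default = init_prec_sketch(params)        # all zeros
prec_scaled = init_prec_sketch(params, 0.5)    # 0.5 * identity
prec_given = init_prec_sketch(params, 2.0 * torch.eye(8))
```

Note that the default of `0.0` yields an all-zero, and therefore singular, precision matrix, so it only defines a proper Normal distribution once enough GGN contributions have been added to it.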
Source code in posteriors/laplace/dense_ggn.py
posteriors.laplace.dense_ggn.update(state, batch, forward, outer_log_likelihood, inplace=False)
Adds the GGN matrix computed over the given batch to the precision matrix of the state.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
`state` | `DenseLaplaceState` | Current state. | required |
`batch` | `Any` | Input data to model. | required |
`forward` | `ForwardFn` | Function that takes parameters and input batch and returns a forward value (e.g. logits), not reduced over the batch, as well as auxiliary information. | required |
`outer_log_likelihood` | `OuterLogProbFn` | A function that takes the output of `forward` and returns the log-likelihood of that output, with no auxiliary information. | required |
`inplace` | `bool` | If True, then the state is updated in place, otherwise a new state is returned. | `False` |
Returns:
Type | Description |
---|---|
`DenseLaplaceState` | Updated DenseLaplaceState. |
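The accumulation step can be sketched in plain PyTorch as follows: compute the GGN of the batch and add it to the current precision matrix. This is a rough illustration under stated assumptions (a flat parameter vector, an `(inputs, labels)` batch, a summed log-likelihood, and the hypothetical helper `ggn_of_batch`), not the library's code.

```python
import torch
from torch.autograd.functional import hessian, jacobian

torch.manual_seed(0)
n_features, n_classes, batch_size = 3, 2, 4
d = n_classes * n_features
batch = (torch.randn(batch_size, n_features),
         torch.randint(0, n_classes, (batch_size,)))


def forward(flat_params, batch):
    # Per-example logits, not reduced over the batch, plus empty aux
    x, _ = batch
    return x @ flat_params.reshape(n_classes, n_features).T, torch.tensor([])


def outer_log_likelihood(logits, batch):
    # Summed categorical log-likelihood of the labels given the logits
    _, y = batch
    log_probs = torch.log_softmax(logits, dim=-1)
    return log_probs[torch.arange(len(y)), y].sum()


def ggn_of_batch(flat_params, batch):
    z, _ = forward(flat_params, batch)
    # Jacobian of the forward output w.r.t. parameters, flattened to (n_out, d)
    J = jacobian(lambda p: forward(p, batch)[0], flat_params).reshape(-1, d)
    # Hessian of the negative log-likelihood w.r.t. the forward output
    H = hessian(lambda out: -outer_log_likelihood(out, batch), z)
    H = H.reshape(z.numel(), z.numel())
    return J.T @ H @ J


flat_params = torch.randn(d)
prec = torch.zeros(d, d)                          # e.g. init_prec=0.0
new_prec = prec + ggn_of_batch(flat_params, batch)
```

Since the log-likelihood is a sum over independent examples, the GGN of a batch is the sum of the per-example GGNs, which is why repeated updates over successive batches accumulate the GGN of the full dataset.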
Source code in posteriors/laplace/dense_ggn.py
posteriors.laplace.dense_ggn.sample(state, sample_shape=torch.Size([]))
Sample from Normal distribution over parameters.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
`state` | `DenseLaplaceState` | State encoding mean and precision matrix. | required |
`sample_shape` | `Size` | Shape of the desired samples. | `Size([])` |
Returns:
Type | Description |
---|---|
`TensorTree` | Sample(s) from the Normal distribution. |
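For intuition, drawing from a Normal distribution parameterised by a mean and a dense precision matrix can be sketched with `torch.distributions.MultivariateNormal`. This shows the underlying operation on a flat parameter vector only; whether the library samples this way, and how it maps samples back onto the parameter tree, is not stated here. The values of `mean` and `prec` are made up.

```python
import torch

torch.manual_seed(0)

d = 4
mean = torch.zeros(d)
# Hypothetical symmetric positive-definite precision matrix
prec = 2.0 * torch.eye(d) + 0.5 * torch.ones(d, d)

dist = torch.distributions.MultivariateNormal(loc=mean, precision_matrix=prec)

single = dist.sample()                   # sample_shape=torch.Size([]) -> shape (d,)
samples = dist.sample(torch.Size([5]))   # shape (5, d)
```

The precision matrix must be positive definite for this to work, so a state initialised with `init_prec=0.0` needs enough accumulated GGN contributions before it can be sampled from.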
Source code in posteriors/laplace/dense_ggn.py