Laplace Diagonal GGN
`posteriors.laplace.diag_ggn.build(forward, outer_log_likelihood, init_prec_diag=0.0)`
Builds a transform for a diagonal Generalized Gauss-Newton (GGN) Laplace approximation.
This is equivalent to the diagonal of the (non-empirical) Fisher information matrix when `outer_log_likelihood` is an exponential-family log-likelihood whose natural parameter is the output of `forward`.
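To make the equivalence concrete, here is a minimal sketch under stated assumptions: a logistic-regression model with logits \(z_n = x_n^\top θ\) and a Bernoulli likelihood, whose natural parameter is the logit. It computes the diagonal of the non-empirical Fisher exactly, by weighting squared per-example gradients over both labels drawn from the model's own predictive distribution, and compares it with the diagonal GGN, for which the Jacobian is the input matrix and the loss Hessian is \(σ(z_n)(1-σ(z_n))\). The toy data and helper names are hypothetical and not part of posteriors.

```python
import torch

torch.manual_seed(0)

# Hypothetical toy data: N inputs with D features; logits z_n = x_n . theta.
N, D = 8, 3
x = torch.randn(N, D, dtype=torch.float64)
theta = torch.randn(D, dtype=torch.float64)


def diag_fisher_exact(theta, x):
    """Diagonal of the (non-empirical) Fisher: sum_n E_{y ~ p(y|x_n)}[g_n(y)^2]."""
    p = torch.sigmoid(x @ theta)  # model probability of y = 1
    fisher = torch.zeros_like(theta)
    for y in (0.0, 1.0):
        # Gradient of log Bernoulli(y | sigmoid(z_n)) w.r.t. theta is (y - p_n) * x_n.
        grad = (y - p).unsqueeze(-1) * x
        prob_y = p if y == 1.0 else 1 - p  # probability the model assigns to this label
        fisher += (prob_y.unsqueeze(-1) * grad**2).sum(0)
    return fisher


def diag_ggn_closed_form(theta, x):
    """Diagonal of J^T H J with J = x and H = diag(sigma(z) * (1 - sigma(z)))."""
    p = torch.sigmoid(x @ theta)
    return ((p * (1 - p)).unsqueeze(-1) * x**2).sum(0)


print(torch.allclose(diag_fisher_exact(theta, x), diag_ggn_closed_form(theta, x)))  # True
```

Both reduce to \(\sum_n σ(z_n)(1-σ(z_n))\, x_{nd}^2\), because \(\mathbb{E}_y[(y-σ(z_n))^2] = σ(z_n)(1-σ(z_n))\) is exactly the Bernoulli loss Hessian with respect to the logit.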
`forward` should output auxiliary information (or `torch.tensor([])`); `outer_log_likelihood` should not.
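A minimal sketch of a compatible pair, assuming a linear classifier whose parameters are a dict of tensors and a batch given as an `(inputs, labels)` tuple (both are illustrative assumptions): `forward` returns unreduced per-example logits together with an empty auxiliary tensor, while `outer_log_likelihood` maps those logits and the batch to a single scalar log-likelihood and returns nothing else.

```python
import torch
from torch.distributions import Categorical


def forward(params, batch):
    """Per-example logits (not reduced over the batch) plus auxiliary information."""
    inputs, _ = batch
    logits = inputs @ params["weight"] + params["bias"]  # shape (N, num_classes)
    return logits, torch.tensor([])  # empty tensor: no auxiliary information


def outer_log_likelihood(logits, batch):
    """Scalar log-likelihood of the labels given the logits; no auxiliary output."""
    _, labels = batch
    return Categorical(logits=logits).log_prob(labels).sum()


# Hypothetical example: 5 inputs with 4 features, 3 classes.
params = {"weight": torch.randn(4, 3), "bias": torch.zeros(3)}
batch = (torch.randn(5, 4), torch.randint(0, 3, (5,)))
logits, aux = forward(params, batch)
print(logits.shape, outer_log_likelihood(logits, batch).shape)  # torch.Size([5, 3]) torch.Size([])
```

A pair shaped like this is what would be passed as `forward` and `outer_log_likelihood`; the reduction over the batch happens only in the outer function.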
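The GGN is defined as $$ G(θ) = J_f(θ) H_l(z) J_f(θ)^T $$ where \(z = f(θ)\) is the output of the forward function \(f\), \(l(z)\) is a loss (the negative log-likelihood) that maps the output of \(f\) to a scalar, \(H_l(z)\) is the Hessian of \(l\) with respect to \(z\), and \(J_f(θ)\) is the Jacobian of \(f\) with respect to \(θ\), arranged with one row per parameter. The diagonal approximation keeps only the diagonal of \(G(θ)\).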
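The definition can be checked numerically. This sketch assumes a hypothetical linear-Gaussian regression, \(z = Xθ\) with loss \(l(z) = \tfrac{1}{2σ^2}\lVert z - y\rVert^2\) (the negative log-likelihood up to a constant), and uses `torch.func.jacrev` and `torch.func.hessian` to form the Jacobian and \(H_l\) explicitly. `jacrev` returns the Jacobian with one row per output, so the product is written \(J^\top H J\); this is the same matrix as \(G(θ)\) above with the Jacobian transposed. Forming the full matrix is only feasible for tiny models and is done here purely for illustration.

```python
import torch
from torch.func import hessian, jacrev

torch.manual_seed(0)

# Hypothetical linear-Gaussian regression: z = f(theta) = X @ theta, y ~ N(z, sigma^2).
N, D, sigma = 10, 4, 0.5
X = torch.randn(N, D, dtype=torch.float64)
y = torch.randn(N, dtype=torch.float64)
theta = torch.randn(D, dtype=torch.float64)


def f(theta):
    return X @ theta  # forward output z, one entry per example


def loss(z):
    # l(z): negative Gaussian log-likelihood summed over the batch, constants dropped.
    return 0.5 * ((z - y) ** 2).sum() / sigma**2


z = f(theta)
J = jacrev(f)(theta)  # shape (N, D): rows index outputs, columns index parameters
H = hessian(loss)(z)  # shape (N, N): equals I / sigma^2 for this loss
G = J.T @ H @ J  # full GGN, shape (D, D)
diag_ggn = torch.diagonal(G)

# For this model the diagonal has the closed form sum_n X[n, d]^2 / sigma^2.
print(torch.allclose(diag_ggn, (X**2).sum(0) / sigma**2))  # True
```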
More information on Fisher and GGN matrices can be found in Martens (2020), and on their use within a Laplace approximation in Daxberger et al. (2021).
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `forward` | `ForwardFn` | Function that takes parameters and an input batch and returns a forward value (e.g. logits), not reduced over the batch, together with auxiliary information. | required |
| `outer_log_likelihood` | `OuterLogProbFn` | A function that takes the output of `forward` and the batch and returns the log-likelihood of the model output, with no auxiliary information. | required |
| `init_prec_diag` | `TensorTree \| float` | Initial diagonal of the precision matrix. Can be a tree with the same structure as the parameters, or a scalar. | `0.0` |
Returns:

| Type | Description |
|---|---|
| `Transform` | Diagonal GGN Laplace approximation transform instance. |
Source code in `posteriors/laplace/diag_ggn.py` (lines 21-64).
`posteriors.laplace.diag_ggn.DiagLaplaceState`
Bases: NamedTuple
State encoding a diagonal Normal distribution over parameters.
Attributes:

| Name | Type | Description |
|---|---|---|
| `params` | `TensorTree` | Mean of the Normal distribution. |
| `prec_diag` | `TensorTree` | Diagonal of the precision matrix of the Normal distribution. |
| `aux` | `Any` | Auxiliary information from the `forward` call. |
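Because the state specifies the Normal distribution \(\mathcal{N}(μ, \operatorname{diag}(p)^{-1})\) with mean \(μ\) given by `params` and precision diagonal \(p\) given by `prec_diag`, its marginal standard deviations are \(p^{-1/2}\). A minimal sketch of evaluating the log-density, assuming `params` and `prec_diag` are flat dicts of tensors and using a hypothetical `ToyDiagState` that only mirrors the field names above:

```python
from typing import Any, NamedTuple

import torch
from torch.distributions import Normal


class ToyDiagState(NamedTuple):
    """Hypothetical stand-in mirroring the fields of DiagLaplaceState."""

    params: dict  # mean, a dict of tensors
    prec_diag: dict  # diagonal precision, same structure as params
    aux: Any = None


def log_density(state, value):
    """Log-density of `value` under N(params, diag(prec_diag)^-1), summed over all leaves."""
    total = torch.tensor(0.0)
    for name, mean in state.params.items():
        scale = state.prec_diag[name].rsqrt()  # standard deviation = precision^(-1/2)
        total = total + Normal(mean, scale).log_prob(value[name]).sum()
    return total


state = ToyDiagState(
    params={"w": torch.zeros(3)},
    prec_diag={"w": torch.full((3,), 4.0)},  # precision 4 means standard deviation 0.5
)
print(log_density(state, {"w": torch.zeros(3)}))
```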
Source code in `posteriors/laplace/diag_ggn.py` (lines 67-78).
`posteriors.laplace.diag_ggn.init(params, init_prec_diag=0.0)`
Initialise a diagonal Normal distribution over parameters.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `params` | `TensorTree` | Mean of the Normal distribution. | required |
| `init_prec_diag` | `TensorTree \| float` | Initial diagonal of the precision matrix. Can be a tree with the same structure as `params`, or a scalar. | `0.0` |
Returns:

| Type | Description |
|---|---|
| `DiagLaplaceState` | Initial `DiagLaplaceState`. |
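When `init_prec_diag` is a scalar it has to be broadcast to a tree shaped like `params`. A minimal sketch of that step, assuming `params` is a flat dict of tensors (the documented type is a general `TensorTree`) and using the hypothetical names `ToyDiagState` and `toy_init`:

```python
from typing import Any, NamedTuple

import torch


class ToyDiagState(NamedTuple):
    """Hypothetical stand-in with the same fields as DiagLaplaceState."""

    params: dict
    prec_diag: dict
    aux: Any = None


def toy_init(params, init_prec_diag=0.0):
    """Broadcast a scalar init_prec_diag to a dict shaped like params; pass trees through."""
    if isinstance(init_prec_diag, (int, float)):
        init_prec_diag = {
            name: torch.full_like(p, float(init_prec_diag)) for name, p in params.items()
        }
    return ToyDiagState(params, init_prec_diag)


params = {"weight": torch.randn(4, 3), "bias": torch.zeros(3)}
state = toy_init(params, init_prec_diag=1.0)
print({k: v.shape for k, v in state.prec_diag.items()})
# {'weight': torch.Size([4, 3]), 'bias': torch.Size([3])}
```

A positive scalar can be read as the precision of an isotropic Gaussian prior. With the default `0.0`, any parameter that the data never constrain keeps zero precision, so its standard deviation \(p^{-1/2}\) is infinite.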
Source code in `posteriors/laplace/diag_ggn.py` (lines 81-101).
`posteriors.laplace.diag_ggn.update(state, batch, forward, outer_log_likelihood, inplace=False)`
Adds the diagonal of the GGN matrix, summed over the given batch, to the diagonal precision of the state.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `state` | `DiagLaplaceState` | Current state. | required |
| `batch` | `Any` | Input data to the model. | required |
| `forward` | `ForwardFn` | Function that takes parameters and an input batch and returns a forward value (e.g. logits), not reduced over the batch, together with auxiliary information. | required |
| `outer_log_likelihood` | `OuterLogProbFn` | A function that takes the output of `forward` and the batch and returns the log-likelihood of the model output, with no auxiliary information. | required |
| `inplace` | `bool` | If `True`, the state is updated in place; otherwise a new state is returned. | `False` |
Returns:

| Type | Description |
|---|---|
| `DiagLaplaceState` | Updated `DiagLaplaceState`. |
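The accumulation can be sketched as follows. This is a hedged illustration, not the library's implementation: it assumes dict-of-tensor parameters, a batch given as an `(inputs, labels)` tuple, a softmax classifier, and the hypothetical names `ToyDiagState` and `toy_update`. It forms the Jacobian with `torch.func.jacrev`, the Hessian of the loss \(l(z)\) (the negative of `outer_log_likelihood`) with `torch.autograd.functional.hessian`, and adds \(\operatorname{diag}(J^\top H J)\) to each leaf of `prec_diag`, always returning a new state (the `inplace=False` behaviour).

```python
from typing import Any, NamedTuple

import torch
from torch.autograd.functional import hessian
from torch.func import jacrev


class ToyDiagState(NamedTuple):
    """Hypothetical stand-in with the same fields as DiagLaplaceState."""

    params: dict
    prec_diag: dict
    aux: Any = None


def forward(params, batch):
    inputs, _ = batch
    return inputs @ params["weight"] + params["bias"], torch.tensor([])


def outer_log_likelihood(logits, batch):
    _, labels = batch
    one_hot = torch.nn.functional.one_hot(labels, logits.shape[-1]).to(logits.dtype)
    return (one_hot * torch.log_softmax(logits, dim=-1)).sum()


def toy_update(state, batch, forward, outer_log_likelihood):
    """Return a new state whose prec_diag has the batch's diagonal GGN added."""
    z, aux = forward(state.params, batch)
    # J: dict of Jacobians dz/dparam, each of shape z.shape + param.shape.
    J = jacrev(lambda p: forward(p, batch)[0])(state.params)
    # H: Hessian of l(z) = -outer_log_likelihood(z, batch), shape z.shape + z.shape.
    H = hessian(lambda zz: -outer_log_likelihood(zz, batch), z.detach())
    H = H.reshape(z.numel(), z.numel())
    new_prec = {}
    for name, p in state.params.items():
        Jp = J[name].reshape(z.numel(), p.numel())
        # diag(J^T H J) without forming the full (P, P) matrix.
        diag = torch.einsum("ap,ab,bp->p", Jp, H, Jp).reshape(p.shape)
        new_prec[name] = state.prec_diag[name] + diag
    return ToyDiagState(state.params, new_prec, aux)


torch.manual_seed(0)
params = {"weight": torch.randn(4, 3, dtype=torch.float64), "bias": torch.zeros(3, dtype=torch.float64)}
batch = (torch.randn(5, 4, dtype=torch.float64), torch.randint(0, 3, (5,)))
state = ToyDiagState(params, {k: torch.zeros_like(v) for k, v in params.items()})
state = toy_update(state, batch, forward, outer_log_likelihood)
print(state.prec_diag["bias"])
```

The full loss Hessian grows quadratically with the number of outputs in the batch, so this is only suitable for tiny examples. For a softmax likelihood the Hessian is block diagonal across examples, \(\operatorname{diag}(s_n) - s_n s_n^\top\) with \(s_n\) the predicted probabilities, which an efficient implementation could exploit.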
Source code in `posteriors/laplace/diag_ggn.py` (lines 104-150).
`posteriors.laplace.diag_ggn.sample(state, sample_shape=torch.Size([]))`
Sample from the diagonal Normal distribution over parameters.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `state` | `DiagLaplaceState` | State encoding mean and diagonal precision. | required |
| `sample_shape` | `Size` | Shape of the desired samples. | `Size([])` |
Returns:

| Type | Description |
|---|---|
| `TensorTree` | Sample(s) from the Normal distribution. |
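Sampling from a diagonal Normal reduces to the reparameterisation \(θ = μ + p^{-1/2} \odot ε\) with \(ε \sim \mathcal{N}(0, I)\), applied leaf by leaf. A minimal sketch, assuming dict-of-tensor `params` and `prec_diag` and that `sample_shape` is prepended to each leaf's shape (the usual `torch.distributions` convention); `toy_sample` is a hypothetical name:

```python
import torch


def toy_sample(params, prec_diag, sample_shape=torch.Size([])):
    """Draw params + eps / sqrt(prec_diag) with eps ~ N(0, I), leaf by leaf."""
    samples = {}
    for name, mean in params.items():
        sd = prec_diag[name].rsqrt()  # standard deviation = precision^(-1/2)
        eps = torch.randn(torch.Size(sample_shape) + mean.shape, dtype=mean.dtype)
        samples[name] = mean + sd * eps  # broadcasts over the leading sample dimensions
    return samples


params = {"w": torch.zeros(2, 3), "b": torch.ones(3)}
prec_diag = {"w": torch.full((2, 3), 4.0), "b": torch.full((3,), 100.0)}
draws = toy_sample(params, prec_diag, sample_shape=torch.Size([1000]))
print(draws["w"].shape)  # torch.Size([1000, 2, 3])
```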
Source code in `posteriors/laplace/diag_ggn.py` (lines 153-166).