Laplace Diagonal Fisher
posteriors.laplace.diag_fisher.build(log_posterior, per_sample=False, init_prec_diag=0.0)
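Builds a transform for a diagonal empirical Fisher information Laplace approximation.
The empirical Fisher is defined here as:
$$ F(\theta) = \sum_i \nabla_\theta \log p(y_i, \theta \mid x_i) \, \nabla_\theta \log p(y_i, \theta \mid x_i)^T $$
where \(p(y_i, \theta \mid x_i)\) is the joint model distribution (equivalent to the posterior up to proportionality) with parameters \(\theta\), inputs \(x_i\) and labels \(y_i\). Only the diagonal of \(F(\theta)\) is kept, i.e. the elementwise sum over samples of the squared gradients \(\left(\nabla_\theta \log p(y_i, \theta \mid x_i)\right)^2\).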
More information on empirical Fisher matrices can be found in Martens (2020), and on their use within a Laplace approximation in Daxberger et al. (2021).
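As a minimal illustration of the definition above, the sketch below computes the diagonal of \(F(\theta)\) in plain Python for a toy linear-Gaussian model and checks that it matches the diagonal of the full outer-product sum. The model, the data and the helpers score, full_empirical_fisher and diag_empirical_fisher are hypothetical and not part of posteriors; the per-sample log joint is assumed to be \(-\tfrac{1}{2}(y_i - \theta^T x_i)^2\) up to a constant, with the prior omitted for brevity.

```python
from typing import List, Sequence

Vector = List[float]


def score(theta: Sequence[float], x: Sequence[float], y: float) -> Vector:
    """Gradient w.r.t. theta of the toy per-sample log joint -0.5 * (y - theta . x) ** 2."""
    residual = y - sum(t * xi for t, xi in zip(theta, x))
    return [residual * xi for xi in x]


def full_empirical_fisher(theta: Sequence[float], xs, ys) -> List[Vector]:
    """Full matrix F: sum over samples of the outer product g_i g_i^T."""
    d = len(theta)
    F = [[0.0] * d for _ in range(d)]
    for x, y in zip(xs, ys):
        g = score(theta, x, y)
        for j in range(d):
            for k in range(d):
                F[j][k] += g[j] * g[k]
    return F


def diag_empirical_fisher(theta: Sequence[float], xs, ys) -> Vector:
    """Diagonal of F only: elementwise sum of squared per-sample gradients."""
    diag = [0.0] * len(theta)
    for x, y in zip(xs, ys):
        for j, g_j in enumerate(score(theta, x, y)):
            diag[j] += g_j ** 2
    return diag


theta = [0.5, -1.0]
xs = [[1.0, 2.0], [0.0, 1.0], [3.0, -1.0]]
ys = [1.0, 0.0, 2.0]

diag = diag_empirical_fisher(theta, xs, ys)  # [8.5, 26.25]
full = full_empirical_fisher(theta, xs, ys)  # off-diagonal entry 11.75 is discarded
```

The diagonal costs one pass over the per-sample gradients and memory linear in the number of parameters, whereas the full matrix grows quadratically; this is the usual motivation for the diagonal variant.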
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| log_posterior | LogProbFn | Function that takes parameters and an input batch and returns the log posterior value (which can be unnormalised) as well as auxiliary information, e.g. from the model call. | required |
| per_sample | bool | If True, log_posterior is assumed to return a vector with one log posterior value for each sample in the batch. If False, log_posterior is assumed to return a scalar log posterior for the whole batch; in this case torch.func.vmap is called, which is typically slower than writing log_posterior to work per sample directly. | False |
| init_prec_diag | TensorTree \| float | Initial diagonal of the precision matrix. Can be a tree with the same structure as the parameters, or a scalar. | 0.0 |
Returns:

| Type | Description |
|---|---|
| Transform | Diagonal empirical Fisher information Laplace approximation transform instance. |
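The per_sample flag concerns how per-sample values (and hence per-sample gradients) are obtained. The stdlib sketch below mimics the idea without PyTorch: a batch-level log posterior is turned into a per-sample one by evaluating it on one-sample batches, which is roughly the role torch.func.vmap plays when per_sample=False. The helper make_per_sample and both toy log posteriors are hypothetical; unlike a real LogProbFn they return only the value (no auxiliary information), and the prior is omitted.

```python
import math
from typing import Callable, List, Tuple

# A toy batch: (inputs, labels), both plain lists of floats.
Batch = Tuple[List[float], List[float]]

LOG_2PI = math.log(2 * math.pi)


def log_post_batch(theta: float, batch: Batch) -> float:
    """Hypothetical scalar log posterior for the whole batch
    (Gaussian log-likelihood of y = theta * x + noise, prior omitted)."""
    xs, ys = batch
    return sum(-0.5 * (y - theta * x) ** 2 - 0.5 * LOG_2PI for x, y in zip(xs, ys))


def log_post_per_sample(theta: float, batch: Batch) -> List[float]:
    """Hypothetical per-sample version: one log posterior value per sample."""
    xs, ys = batch
    return [-0.5 * (y - theta * x) ** 2 - 0.5 * LOG_2PI for x, y in zip(xs, ys)]


def make_per_sample(f: Callable[[float, Batch], float]) -> Callable[[float, Batch], List[float]]:
    """Turn a batch-scalar log posterior into a per-sample one by evaluating it
    on one-sample batches; the Python loop stands in for torch.func.vmap."""
    def f_per_sample(theta: float, batch: Batch) -> List[float]:
        xs, ys = batch
        return [f(theta, ([x], [y])) for x, y in zip(xs, ys)]
    return f_per_sample


batch = ([1.0, 2.0, -1.0], [0.5, 1.5, 0.0])
wrapped = make_per_sample(log_post_batch)(0.7, batch)
direct = log_post_per_sample(0.7, batch)
```

Writing log_posterior per sample from the start skips this wrapping step, which is consistent with the per_sample description above noting that the vmap route is typically slower.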
Source code in posteriors/laplace/diag_fisher.py (lines 17-53)
posteriors.laplace.diag_fisher.DiagLaplaceState
Bases: NamedTuple
State encoding a diagonal Normal distribution over parameters.
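Written out, the state describes an independent Normal for each parameter. Writing \(\mu\) for params and \(\lambda\) for prec_diag, with \(j\) running over every entry of every leaf:

$$ q(\theta) = \mathcal{N}\left(\theta \mid \mu, \operatorname{diag}(\lambda)^{-1}\right), \qquad \log q(\theta) = \sum_j \left[ \tfrac{1}{2} \log \lambda_j - \tfrac{1}{2} \log 2\pi - \tfrac{1}{2} \lambda_j (\theta_j - \mu_j)^2 \right] $$

so each marginal variance is \(1 / \lambda_j\), which is only finite when \(\lambda_j > 0\).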
Attributes:

| Name | Type | Description |
|---|---|---|
| params | TensorTree | Mean of the Normal distribution. |
| prec_diag | TensorTree | Diagonal of the precision matrix of the Normal distribution. |
| aux | Any | Auxiliary information from the log_posterior call. |
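Because the state is a NamedTuple, the tuple itself is immutable: producing a modified state means building a new tuple, for example with _replace, rather than assigning to a field. The stdlib sketch below uses a hypothetical ToyDiagLaplaceState with plain dicts of floats standing in for TensorTree leaves; it mirrors the field names above but is not the posteriors class.

```python
import math
from typing import Any, Dict, NamedTuple

Tree = Dict[str, float]  # stand-in for a TensorTree with scalar leaves


class ToyDiagLaplaceState(NamedTuple):
    """Hypothetical stand-in mirroring the fields of DiagLaplaceState."""
    params: Tree      # mean of the Normal distribution
    prec_diag: Tree   # diagonal of the precision matrix
    aux: Any = None   # auxiliary output of the last log_posterior call

    def std_diag(self) -> Tree:
        # Marginal standard deviations: 1 / sqrt(precision) for each leaf.
        return {k: 1.0 / math.sqrt(p) for k, p in self.prec_diag.items()}


state = ToyDiagLaplaceState(params={"w": 0.3, "b": -1.2}, prec_diag={"w": 4.0, "b": 25.0})

# _replace returns a new tuple; `state` itself keeps its original prec_diag.
new_state = state._replace(prec_diag={k: p + 1.0 for k, p in state.prec_diag.items()})
# new_state.prec_diag == {"w": 5.0, "b": 26.0}
```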
Source code in posteriors/laplace/diag_fisher.py (lines 56-67)
posteriors.laplace.diag_fisher.init(params, init_prec_diag=0.0)
Initialise a diagonal Normal distribution over parameters.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| params | TensorTree | Mean of the Normal distribution. | required |
| init_prec_diag | TensorTree \| float | Initial diagonal of the precision matrix. Can be a tree with the same structure as params, or a scalar. | 0.0 |
Returns:

| Type | Description |
|---|---|
| DiagLaplaceState | Initial DiagLaplaceState. |
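Since init_prec_diag may be either a scalar or a tree shaped like params, initialisation amounts to broadcasting a scalar to every leaf or copying a matching tree. A minimal sketch of that behaviour, assuming dict-of-lists trees and a hypothetical toy_init helper and ToyState class (not the posteriors implementation):

```python
from typing import Dict, List, NamedTuple, Union

Tree = Dict[str, List[float]]  # stand-in for a TensorTree: named 1-D "tensors"


class ToyState(NamedTuple):
    params: Tree     # mean of the Normal distribution
    prec_diag: Tree  # diagonal precision, same structure as params


def toy_init(params: Tree, init_prec_diag: Union[Tree, float] = 0.0) -> ToyState:
    """Hypothetical init: broadcast a scalar precision to every leaf, or copy a tree."""
    if isinstance(init_prec_diag, (int, float)):
        # Scalar case: every entry of every leaf gets the same precision.
        prec = {k: [float(init_prec_diag)] * len(v) for k, v in params.items()}
    else:
        # Tree case: the structure must match params.
        if init_prec_diag.keys() != params.keys():
            raise ValueError("init_prec_diag must have the same structure as params")
        prec = {k: list(v) for k, v in init_prec_diag.items()}
    # Copy params so later in-place updates cannot alias the caller's data.
    return ToyState(params={k: list(v) for k, v in params.items()}, prec_diag=prec)


params = {"weight": [0.1, -0.4, 2.0], "bias": [0.0]}
state = toy_init(params, init_prec_diag=1.0)
# state.prec_diag == {"weight": [1.0, 1.0, 1.0], "bias": [1.0]}
```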
Source code in posteriors/laplace/diag_fisher.py (lines 70-90)
posteriors.laplace.diag_fisher.update(state, batch, log_posterior, per_sample=False, inplace=False)
Adds the diagonal empirical Fisher information matrix, summed over the given batch, to the diagonal precision of the state.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| state | DiagLaplaceState | Current state. | required |
| batch | Any | Input data to log_posterior. | required |
| log_posterior | LogProbFn | Function that takes parameters and an input batch and returns the log posterior value (which can be unnormalised) as well as auxiliary information, e.g. from the model call. | required |
| per_sample | bool | If True, log_posterior is assumed to return a vector with one log posterior value for each sample in the batch. If False, log_posterior is assumed to return a scalar log posterior for the whole batch; in this case torch.func.vmap is called, which is typically slower than writing log_posterior to work per sample directly. | False |
| inplace | bool | If True, the state is updated in place; otherwise a new state is returned. | False |
Returns:

| Type | Description |
|---|---|
| DiagLaplaceState | Updated DiagLaplaceState. |
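The update step described above can be sketched as: compute one gradient per sample, square it elementwise, sum over the batch, and add the result to prec_diag while leaving the mean untouched. The sketch below assumes a toy linear-Gaussian model with an analytic score (prior omitted) and uses hypothetical ToyState and toy_update names with flat lists as leaves; it is not the posteriors source, but it shows the difference between the inplace=True and inplace=False paths.

```python
from typing import List, NamedTuple, Sequence, Tuple

# A batch of (inputs, labels); each input is a feature vector.
Batch = Tuple[List[List[float]], List[float]]


class ToyState(NamedTuple):
    """Hypothetical stand-in for DiagLaplaceState with flat-list leaves."""
    params: List[float]     # mean, one entry per parameter
    prec_diag: List[float]  # diagonal precision, same length as params


def score(params: Sequence[float], x: Sequence[float], y: float) -> List[float]:
    """Per-sample gradient of the toy log joint -0.5 * (y - params . x) ** 2."""
    residual = y - sum(p * xi for p, xi in zip(params, x))
    return [residual * xi for xi in x]


def toy_update(state: ToyState, batch: Batch, inplace: bool = False) -> ToyState:
    """Add the batch's diagonal empirical Fisher to prec_diag; params are unchanged."""
    xs, ys = batch
    batch_diag = [0.0] * len(state.params)
    for x, y in zip(xs, ys):
        for j, g in enumerate(score(state.params, x, y)):
            batch_diag[j] += g * g  # squared per-sample gradient, summed over the batch
    if inplace:
        # Mutate the existing precision list and return the same state object.
        for j, f in enumerate(batch_diag):
            state.prec_diag[j] += f
        return state
    # Build a new state; the input state is left untouched.
    return state._replace(prec_diag=[p + f for p, f in zip(state.prec_diag, batch_diag)])


state = ToyState(params=[0.5, -1.0], prec_diag=[1.0, 1.0])
batch1 = ([[1.0, 2.0], [0.0, 1.0]], [1.0, 0.0])
batch2 = ([[3.0, -1.0]], [2.0])

new_state = toy_update(toy_update(state, batch1), batch2)
# new_state.prec_diag == [9.5, 27.25]; state.prec_diag is still [1.0, 1.0]
```

Because the precision is a running sum, feeding the three samples in two batches gives the same result as one pass over all of them, which is what makes repeated calls over a data loader meaningful.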
Source code in posteriors/laplace/diag_fisher.py (lines 93-136)
posteriors.laplace.diag_fisher.sample(state, sample_shape=torch.Size([]))
Sample from the diagonal Normal distribution over parameters.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| state | DiagLaplaceState | State encoding the mean and diagonal precision. | required |
| sample_shape | torch.Size | Shape of the desired samples. | torch.Size([]) |
Returns:

| Type | Description |
|---|---|
| TensorTree | Sample(s) from the Normal distribution. |
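Sampling from a diagonal Normal parameterised by its precision only needs elementwise operations: draw standard normal noise and scale it by \(1/\sqrt{\lambda_j}\), where \(\lambda\) is prec_diag. The stdlib sketch below shows that rule with a hypothetical toy_sample helper and flat lists instead of TensorTrees; it takes an integer number of samples as a simplified stand-in for sample_shape.

```python
import math
import random
from typing import List, Optional, Sequence


def toy_sample(
    mean: Sequence[float],
    prec_diag: Sequence[float],
    num_samples: int,
    rng: Optional[random.Random] = None,
) -> List[List[float]]:
    """Draw num_samples vectors from N(mean, diag(1 / prec_diag))."""
    if rng is None:
        rng = random.Random()
    # Standard deviation is 1 / sqrt(precision); zero precision raises ZeroDivisionError.
    std = [1.0 / math.sqrt(p) for p in prec_diag]
    return [
        [m + s * rng.gauss(0.0, 1.0) for m, s in zip(mean, std)]
        for _ in range(num_samples)
    ]


mean = [0.0, 3.0]
prec_diag = [4.0, 100.0]  # standard deviations 0.5 and 0.1
samples = toy_sample(mean, prec_diag, num_samples=20000, rng=random.Random(0))
```

With the default init_prec_diag=0.0, a parameter that receives no Fisher information keeps zero precision and hence unbounded variance; this sketch fails loudly in that case rather than returning infinite samples.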
Source code in posteriors/laplace/diag_fisher.py (lines 139-152)