Laplace Dense Fisher
posteriors.laplace.dense_fisher.build(log_posterior, per_sample=False, init_prec=0.0)
Builds a transform for a dense empirical Fisher information Laplace approximation.
The empirical Fisher is defined here as:

$$ F(\theta) = \sum_i \nabla_\theta \log p(y_i, \theta \mid x_i) \, \nabla_\theta \log p(y_i, \theta \mid x_i)^T $$

where \(p(y_i, \theta \mid x_i)\) is the joint model distribution (equivalent to the posterior up to proportionality) with parameters \(\theta\), inputs \(x_i\) and labels \(y_i\).
More information on empirical Fisher matrices can be found in Martens, 2020, and on their use within a Laplace approximation in Daxberger et al., 2021.
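The definition above amounts to stacking the per-sample gradients as rows of a matrix \(G\) and forming \(F = G^T G\). The sketch below illustrates this with `torch.func`, assuming a toy linear-Gaussian per-sample term `log_joint_single` (hypothetical, not part of `posteriors`); it is an illustration of the formula, not the library's implementation.

```python
import torch
from torch.func import grad, vmap

def log_joint_single(theta, x, y):
    # Toy per-sample term log p(y_i, θ | x_i): a linear-Gaussian likelihood
    # with unit noise variance; any prior contribution is omitted for brevity.
    return -0.5 * (y - x @ theta) ** 2

def empirical_fisher(theta, xs, ys):
    # Rows of G are the per-sample gradients g_i = ∇_θ log p(y_i, θ | x_i).
    G = vmap(grad(log_joint_single), in_dims=(None, 0, 0))(theta, xs, ys)
    # F(θ) = Σ_i g_i g_iᵀ = Gᵀ G, a dense (n, n) matrix.
    return G.T @ G

torch.manual_seed(0)
theta = torch.randn(3)
xs, ys = torch.randn(5, 3), torch.randn(5)
F = empirical_fisher(theta, xs, ys)
print(F.shape)  # torch.Size([3, 3])
```

Because \(F\) is a sum of one rank-one term per sample, its rank is at most the number of samples; with fewer samples than parameters it is singular, and a positive `init_prec` keeps the resulting precision matrix positive definite.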
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `log_posterior` | `LogProbFn` | Function that takes parameters and input batch and returns the log posterior value (which can be unnormalised) as well as auxiliary information, e.g. from the model call. | required |
| `per_sample` | `bool` | If `True`, `log_posterior` is assumed to return a vector of log posteriors, one for each sample in the batch. If `False`, `log_posterior` is assumed to return a scalar log posterior for the whole batch; in this case `torch.func.vmap` is called, which is typically slower than writing `log_posterior` to be per sample directly. | `False` |
| `init_prec` | `Tensor \| float` | Initial precision matrix. If it is a float, it is defined as an identity matrix scaled by that float. | `0.0` |
Returns:
| Type | Description |
|---|---|
| `Transform` | Empirical Fisher information Laplace approximation transform instance. |
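The `per_sample` flag distinguishes two ways of writing `log_posterior`. The sketch below uses hypothetical toy functions and a hypothetical helper `vmap_per_sample`, an assumption about how a per-sample version can be derived rather than the library's internal code, to show that vmapping a batch-scalar function over size-one batches recovers the per-sample vector. The auxiliary output is omitted for brevity.

```python
import torch
from torch.func import vmap

def log_post_batch(params, batch):
    # per_sample=False style: one scalar for the whole batch (aux omitted).
    x, y = batch
    return (-0.5 * (y - x @ params) ** 2).sum()

def log_post_per_sample(params, batch):
    # per_sample=True style: one log posterior value per sample (aux omitted).
    x, y = batch
    return -0.5 * (y - x @ params) ** 2

def vmap_per_sample(f):
    # Hypothetical helper: evaluate a batch-scalar f on each sample as a
    # batch of size one, vectorised over the batch with torch.func.vmap.
    def f_per_sample(params, batch):
        x, y = batch
        return vmap(lambda xi, yi: f(params, (xi.unsqueeze(0), yi.unsqueeze(0))))(x, y)
    return f_per_sample

torch.manual_seed(0)
params = torch.randn(3)
x, y = torch.randn(4, 3), torch.randn(4)
via_vmap = vmap_per_sample(log_post_batch)(params, (x, y))
direct = log_post_per_sample(params, (x, y))
print(via_vmap.shape)  # torch.Size([4])
```

Writing the per-sample version directly skips the extra `vmap` wrapper, which is consistent with the note above that `per_sample=True` is typically faster.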
Source code in posteriors/laplace/dense_fisher.py
posteriors.laplace.dense_fisher.DenseLaplaceState
Bases: NamedTuple
State encoding a Normal distribution over parameters, with a dense precision matrix.
Attributes:
| Name | Type | Description |
|---|---|---|
| `params` | `TensorTree` | Mean of the Normal distribution. |
| `prec` | `Tensor` | Precision matrix of the Normal distribution. |
| `aux` | `Any` | Auxiliary information from the `log_posterior` call. |
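Because the precision matrix is dense, its size grows quadratically with the total number of parameter elements. The sketch below uses a hypothetical `ToyDenseLaplaceState` with the same three fields and assumes, for illustration, that `prec` is indexed by the flattened parameter vector of a dict-based parameter tree.

```python
from typing import Any, NamedTuple

import torch

class ToyDenseLaplaceState(NamedTuple):
    # Hypothetical stand-in mirroring the documented fields.
    params: dict        # mean: a dict of tensors standing in for a TensorTree
    prec: torch.Tensor  # dense (n, n) precision over all flattened params
    aux: Any = None     # auxiliary output from the last log_posterior call

params = {"weight": torch.zeros(4, 3), "bias": torch.zeros(4)}
n = sum(p.numel() for p in params.values())  # 4 * 3 + 4 = 16 elements
state = ToyDenseLaplaceState(params, torch.eye(n))

# prec holds n * n entries, so memory grows quadratically with model size.
print(state.prec.shape)  # torch.Size([16, 16])

# NamedTuples are immutable: _replace builds a new state object.
new_state = state._replace(prec=state.prec * 2.0)
```

Since a `NamedTuple` is immutable, `_replace` leaves the original state untouched, mirroring the choice of returning a new state rather than modifying the existing one.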
Source code in posteriors/laplace/dense_fisher.py
posteriors.laplace.dense_fisher.init(params, init_prec=0.0)
Initialise Normal distribution over parameters with a dense precision matrix.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `params` | `TensorTree` | Mean of the Normal distribution. | required |
| `init_prec` | `Tensor \| float` | Initial precision matrix. If it is a float, it is defined as an identity matrix scaled by that float. | `0.0` |
Returns:
| Type | Description |
|---|---|
| `DenseLaplaceState` | Initial `DenseLaplaceState`. |
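A minimal sketch of how `init_prec` could be turned into a dense precision matrix, assuming parameters are a dict of tensors. `toy_init` is hypothetical and returns a plain `(mean, prec)` tuple rather than a `DenseLaplaceState`.

```python
import torch

def toy_init(params, init_prec=0.0):
    # Hypothetical sketch of init for a dict of tensors: returns (mean, prec).
    n = sum(p.numel() for p in params.values())
    if isinstance(init_prec, (int, float)):
        # Scalar: identity matrix scaled by init_prec.
        prec = init_prec * torch.eye(n)
    else:
        # Tensor: assumed to already be a dense (n, n) precision matrix.
        if init_prec.shape != (n, n):
            raise ValueError(f"expected shape {(n, n)}, got {tuple(init_prec.shape)}")
        prec = init_prec
    return params, prec

params = {"weight": torch.zeros(2, 2), "bias": torch.zeros(2)}
_, prec = toy_init(params, init_prec=0.1)
print(prec.shape)  # torch.Size([6, 6])
```

With the default `init_prec=0.0` the initial precision is the zero matrix, so all of the precision comes from subsequent `update` calls.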
Source code in posteriors/laplace/dense_fisher.py
posteriors.laplace.dense_fisher.update(state, batch, log_posterior, per_sample=False, inplace=False)
Adds the empirical Fisher information matrix, summed over the given batch, to the precision matrix of the state.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `state` | `DenseLaplaceState` | Current state. | required |
| `batch` | `Any` | Input data to `log_posterior`. | required |
| `log_posterior` | `LogProbFn` | Function that takes parameters and input batch and returns the log posterior value (which can be unnormalised) as well as auxiliary information. | required |
| `per_sample` | `bool` | If `True`, `log_posterior` is assumed to return a vector of log posteriors, one for each sample in the batch. If `False`, `log_posterior` is assumed to return a scalar log posterior for the whole batch; in this case `torch.func.vmap` is called, which is typically slower than writing `log_posterior` to be per sample directly. | `False` |
| `inplace` | `bool` | If `True`, the state is updated in place. Otherwise, a new state is returned. | `False` |
Returns:
| Type | Description |
|---|---|
| `DenseLaplaceState` | Updated `DenseLaplaceState`. |
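Because the empirical Fisher is a sum over samples, calling `update` once per batch accumulates the Fisher over the whole dataset. The hypothetical `toy_update` below assumes the per-sample gradients for a batch are already stacked as rows of a matrix; it checks that two batch updates match a single update on the concatenated data and mimics the `inplace` flag.

```python
import torch

def toy_update(prec, per_sample_grads, inplace=False):
    # per_sample_grads: (batch_size, n) matrix whose rows are g_i.
    fisher = per_sample_grads.T @ per_sample_grads  # Σ_i g_i g_iᵀ for this batch
    if inplace:
        prec.add_(fisher)  # mutate the existing tensor
        return prec
    return prec + fisher   # return a new tensor, leaving prec untouched

torch.manual_seed(0)
g_a = torch.randn(8, 3)  # per-sample gradients from batch A
g_b = torch.randn(5, 3)  # per-sample gradients from batch B
prec0 = torch.zeros(3, 3)

# Two successive batch updates versus one update on all the data.
prec = toy_update(toy_update(prec0, g_a), g_b)
full = toy_update(prec0, torch.cat([g_a, g_b]))
print(torch.allclose(prec, full, atol=1e-5))  # True
```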
Source code in posteriors/laplace/dense_fisher.py
posteriors.laplace.dense_fisher.sample(state, sample_shape=torch.Size([]))
Sample from Normal distribution over parameters.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `state` | `DenseLaplaceState` | State encoding mean and precision matrix. | required |
| `sample_shape` | `Size` | Shape of the desired samples. | `torch.Size([])` |
Returns:
| Type | Description |
|---|---|
| `TensorTree` | Sample(s) from the Normal distribution. |
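One way to draw from a Normal distribution given its precision matrix is `torch.distributions.MultivariateNormal`, which accepts `precision_matrix` directly. The sketch below assumes a flattened mean vector and a hypothetical `unflatten` helper for dict parameters. It requires `prec` to be positive definite, which does not hold, for example, with `init_prec=0.0` and fewer samples than parameters.

```python
import torch
from torch.distributions import MultivariateNormal

def toy_sample(mean_flat, prec, sample_shape=torch.Size([])):
    # Hypothetical sketch: draw from N(mean, prec^{-1}); MultivariateNormal
    # derives a scale factor from the precision matrix internally.
    dist = MultivariateNormal(loc=mean_flat, precision_matrix=prec)
    return dist.sample(sample_shape)  # shape: sample_shape + (n,)

def unflatten(flat, like):
    # Split the last dimension of `flat` back into tensors shaped like `like`.
    out, start = {}, 0
    for name, p in like.items():
        chunk = flat[..., start:start + p.numel()]
        out[name] = chunk.reshape(flat.shape[:-1] + p.shape)
        start += p.numel()
    return out

params = {"weight": torch.zeros(2, 2), "bias": torch.ones(2)}
mean_flat = torch.cat([p.flatten() for p in params.values()])
prec = 4.0 * torch.eye(mean_flat.numel())  # covariance 0.25 * I

torch.manual_seed(0)
draws = toy_sample(mean_flat, prec, torch.Size([1000]))
samples = unflatten(draws, params)
print(samples["weight"].shape)  # torch.Size([1000, 2, 2])
```

Leading dimensions follow `sample_shape`, so the default `torch.Size([])` yields a single draw with the same shapes as the parameters.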
Source code in posteriors/laplace/dense_fisher.py