EKF Diagonal Fisher

`posteriors.ekf.diag_fisher.build(log_likelihood, lr, transition_sd=0.0, per_sample=False, init_sds=1.0)`
Builds a transform to implement an extended Kalman Filter update.
EKF applies an online update to a (diagonal) Gaussian posterior over the parameters.
The approximate Bayesian update is based on the linearization

$$
\log p(\theta \mid y) \approx \log p(\theta) + \varepsilon\, g(\mu)^\top (\theta - \mu) - \tfrac{1}{2}\, \varepsilon\, (\theta - \mu)^\top F_d(\mu)\, (\theta - \mu)
$$

where \(\mu\) is the mean of the prior distribution, \(\varepsilon\) is the learning rate (or equivalently the likelihood inverse temperature), \(g(\mu)\) is the gradient of the log likelihood at \(\mu\), and \(F_d(\mu)\) is the diagonal empirical Fisher information matrix at \(\mu\) for data \(y\). Completing the square regains a diagonal Normal distribution over the parameters.
For more information on extended Kalman filtering, as well as its equivalence to (online) natural gradient descent, see Ollivier, 2019.
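Concretely, completing the square under the quadratic approximation above yields the closed-form elementwise update below (a sketch; \(\sigma^2\) denotes the prior variance diagonal after any transition-noise inflation):

$$
\sigma_{\text{new}}^{-2} = \sigma^{-2} + \varepsilon\, F_d(\mu), \qquad \mu_{\text{new}} = \mu + \varepsilon\, \sigma_{\text{new}}^{2}\, g(\mu)
$$

with all operations applied elementwise across the parameter tree.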
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `log_likelihood` | `LogProbFn` | Function that takes parameters and an input batch and returns the log-likelihood value as well as auxiliary information, e.g. from the model call. | required |
| `lr` | `float` | Inverse temperature of the update, which behaves like a learning rate. | required |
| `transition_sd` | `float` | Standard deviation of the transition noise, used to additively inflate the diagonal covariance before the update. | `0.0` |
| `per_sample` | `bool` | If `True`, `log_likelihood` is assumed to return a vector of log likelihoods, one for each sample in the batch. If `False`, `log_likelihood` is assumed to return a scalar log likelihood for the whole batch, in which case `torch.func.vmap` is called; this is typically slower than writing `log_likelihood` to be per-sample directly. | `False` |
| `init_sds` | `TensorTree \| float` | Initial square-root diagonal of the covariance matrix of the Normal distribution. Can be a tree like `params` or a scalar. | `1.0` |
Returns:

| Type | Description |
|---|---|
| `Transform` | Diagonal EKF transform instance. |
Source code in `posteriors/ekf/diag_fisher.py`
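A minimal usage sketch, assuming the standard posteriors transform interface (`state = transform.init(params)` then `state = transform.update(state, batch)`); the model, data, and `log_likelihood` below are hypothetical:

```python
import torch
import posteriors

# Hypothetical model: linear regression with unit observation noise.
params = {"weight": torch.zeros(3)}

def log_likelihood(params, batch):
    # per_sample=True contract: return one log likelihood per sample,
    # plus auxiliary information (empty here).
    x, y = batch
    pred = x @ params["weight"]
    return torch.distributions.Normal(pred, 1.0).log_prob(y), torch.tensor([])

transform = posteriors.ekf.diag_fisher.build(log_likelihood, lr=1e-1, per_sample=True)
state = transform.init(params)

for _ in range(10):  # stand-in for iterating over a dataloader
    x = torch.randn(32, 3)
    y = x @ torch.tensor([1.0, -2.0, 0.5]) + 0.1 * torch.randn(32)
    state = transform.update(state, (x, y))

print(state.params["weight"], state.sd_diag["weight"])
```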
`posteriors.ekf.diag_fisher.EKFDiagState`

Bases: `NamedTuple`
State encoding a diagonal Normal distribution over parameters.
Attributes:

| Name | Type | Description |
|---|---|---|
| `params` | `TensorTree` | Mean of the Normal distribution. |
| `sd_diag` | `TensorTree` | Square-root diagonal of the covariance matrix of the Normal distribution. |
| `log_likelihood` | `Tensor` | Log likelihood of the data given the parameters. |
| `aux` | `Any` | Auxiliary information from the `log_likelihood` call. |
Source code in `posteriors/ekf/diag_fisher.py`
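Since the state is a `NamedTuple`, its fields can be read directly; a small sketch, assuming a `state` produced by the `init` and `update` functions below:

```python
# Fields of an EKFDiagState.
mean_tree = state.params        # TensorTree: posterior mean
sd_tree = state.sd_diag         # TensorTree: posterior standard deviations
last_ll = state.log_likelihood  # Tensor: log likelihood of the latest batch
model_aux = state.aux           # Any: auxiliary output of log_likelihood
```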
`posteriors.ekf.diag_fisher.init(params, init_sds=1.0)`
Initialise diagonal Normal distribution over parameters.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `params` | `TensorTree` | Initial mean of the Normal distribution. | required |
| `init_sds` | `TensorTree \| float` | Initial square-root diagonal of the covariance matrix of the Normal distribution. Can be a tree like `params` or a scalar. | `1.0` |
Returns:

| Type | Description |
|---|---|
| `EKFDiagState` | Initial `EKFDiagState`. |
Source code in `posteriors/ekf/diag_fisher.py`
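For example, `init_sds` can be a scalar or a tree matching `params`; a hypothetical sketch:

```python
import torch
import posteriors

params = {"weight": torch.zeros(3), "bias": torch.zeros(1)}

# Scalar: every parameter starts with standard deviation 0.5.
state = posteriors.ekf.diag_fisher.init(params, init_sds=0.5)

# Tree like params: per-tensor initial standard deviations.
init_sds = {"weight": 0.1 * torch.ones(3), "bias": torch.ones(1)}
state = posteriors.ekf.diag_fisher.init(params, init_sds=init_sds)
```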
`posteriors.ekf.diag_fisher.update(state, batch, log_likelihood, lr, transition_sd=0.0, per_sample=False, inplace=False)`
Applies an extended Kalman Filter update to the diagonal Normal distribution. The approximate Bayesian update is based on the linearization

$$
\log p(\theta \mid y) \approx \log p(\theta) + \varepsilon\, g(\mu)^\top (\theta - \mu) - \tfrac{1}{2}\, \varepsilon\, (\theta - \mu)^\top F_d(\mu)\, (\theta - \mu)
$$

where \(\mu\) is the mean of the prior distribution, \(\varepsilon\) is the learning rate (or equivalently the likelihood inverse temperature), \(g(\mu)\) is the gradient of the log likelihood at \(\mu\), and \(F_d(\mu)\) is the diagonal empirical Fisher information matrix at \(\mu\) for data \(y\). Completing the square regains a diagonal Normal distribution over the parameters.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `state` | `EKFDiagState` | Current state. | required |
| `batch` | `Any` | Input data to `log_likelihood`. | required |
| `log_likelihood` | `LogProbFn` | Function that takes parameters and an input batch and returns the log-likelihood value as well as auxiliary information, e.g. from the model call. | required |
| `lr` | `float` | Inverse temperature of the update, which behaves like a learning rate. | required |
| `transition_sd` | `float` | Standard deviation of the transition noise, used to additively inflate the diagonal covariance before the update. | `0.0` |
| `per_sample` | `bool` | If `True`, `log_likelihood` is assumed to return a vector of log likelihoods, one for each sample in the batch. If `False`, `log_likelihood` is assumed to return a scalar log likelihood for the whole batch, in which case `torch.func.vmap` is called; this is typically slower than writing `log_likelihood` to be per-sample directly. | `False` |
| `inplace` | `bool` | Whether to update the state parameters in-place. | `False` |
Returns:

| Type | Description |
|---|---|
| `EKFDiagState` | Updated `EKFDiagState`. |
Source code in `posteriors/ekf/diag_fisher.py`
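`update` can also be called directly in its functional form; a sketch reusing the hypothetical `log_likelihood`, `state`, and batch `(x, y)` from the `build` example above:

```python
# One EKF step: inflate the covariance by the transition noise, then
# condition on the new batch. inplace=False returns a fresh state.
state = posteriors.ekf.diag_fisher.update(
    state,
    batch=(x, y),
    log_likelihood=log_likelihood,
    lr=1e-1,
    transition_sd=0.01,
    per_sample=True,
    inplace=False,
)
```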
`posteriors.ekf.diag_fisher.sample(state, sample_shape=torch.Size([]))`
Sample from the diagonal Normal distribution over parameters.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `state` | `EKFDiagState` | State encoding mean and standard deviations. | required |
| `sample_shape` | `Size` | Shape of the desired samples. | `Size([])` |
Returns:

| Type | Description |
|---|---|
| `TensorTree` | Sample(s) from the Normal distribution. |
Source code in `posteriors/ekf/diag_fisher.py`
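A short sketch of drawing posterior samples, continuing from a `state` as in the examples above:

```python
import torch
import posteriors

# Single sample: a TensorTree with the same structure as state.params.
single = posteriors.ekf.diag_fisher.sample(state)

# Five samples: each leaf gains a leading dimension of size 5.
five = posteriors.ekf.diag_fisher.sample(state, sample_shape=torch.Size([5]))
```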