EKF Dense Fisher
posteriors.ekf.dense_fisher.build(log_likelihood, lr, transition_cov=0.0, per_sample=False, init_cov=1.0)
Builds a transform to implement an extended Kalman Filter update.
EKF applies an online update to a Gaussian posterior over the parameters.
The approximate Bayesian update is based on the linearization
$$ \log p(θ | y) ≈ \log p(θ) + ε g(μ)^T (θ - μ) + \frac{ε}{2} (θ - μ)^T F(μ) (θ - μ) $$
where \(μ\) is the mean of the prior distribution, \(ε\) is the learning rate (or equivalently the likelihood inverse temperature), \(g(μ)\) is the gradient of the log-likelihood at \(μ\), and \(F(μ)\) is the empirical Fisher information matrix at \(μ\) for data \(y\).
For more information on extended Kalman filtering as well as an equivalence to (online) natural gradient descent see Ollivier, 2019.
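For reference, the Gaussian update implied by this quadratic approximation takes the standard dense-Fisher EKF form shown below (a sketch of the usual expressions, with the Fisher term entering as added precision; the implementation may differ in detail):
$$ \Sigma_{\mathrm{pred}} = \Sigma + Q, \qquad \Sigma_{\mathrm{new}} = \left(\Sigma_{\mathrm{pred}}^{-1} + ε F(μ)\right)^{-1}, \qquad μ_{\mathrm{new}} = μ + ε \Sigma_{\mathrm{new}} g(μ), $$
where \(\Sigma\) is the current covariance and \(Q\) is the transition covariance (transition_cov).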
Parameters:

Name | Type | Description | Default |
---|---|---|---|
log_likelihood | LogProbFn | Function that takes parameters and an input batch and returns the log-likelihood value as well as auxiliary information, e.g. from the model call. | required |
lr | float | Inverse temperature of the update, which behaves like a learning rate. | required |
transition_cov | Tensor \| float | Covariance of the transition noise, used to additively inflate the covariance before the update. | 0.0 |
per_sample | bool | If True, log_likelihood is assumed to return a vector of log-likelihoods, one for each sample in the batch. If False, log_likelihood is assumed to return a scalar log-likelihood for the whole batch; in that case torch.func.vmap is called internally, which is typically slower than writing log_likelihood to be per-sample directly. | False |
init_cov | Tensor \| float | Initial covariance of the Normal distribution. Can be a torch.Tensor or a scalar. | 1.0 |
Returns:

Type | Description |
---|---|
Transform | EKF transform instance. |
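A minimal usage sketch, assuming the standard posteriors transform interface (transform.init and transform.update); the toy model, data, and learning rate below are illustrative, not from the source:

```python
import torch
import posteriors

def log_likelihood(params, batch):
    # Toy Gaussian log-likelihood for a linear model: returns (log-lik, aux).
    x, y = batch
    pred = x @ params["weights"]
    return -0.5 * ((pred - y) ** 2).sum(), pred

params = {"weights": torch.zeros(3)}
batches = [(torch.randn(8, 3), torch.randn(8)) for _ in range(10)]

transform = posteriors.ekf.dense_fisher.build(log_likelihood, lr=1e-1)
state = transform.init(params)
for batch in batches:
    state = transform.update(state, batch)

# state.params holds the posterior mean, state.cov the dense covariance
# over the flattened parameters (here 3 x 3).
```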
Source code in posteriors/ekf/dense_fisher.py
posteriors.ekf.dense_fisher.EKFDenseState
Bases: NamedTuple
State encoding a Normal distribution over parameters.
Attributes:

Name | Type | Description |
---|---|---|
params | TensorTree | Mean of the Normal distribution. |
cov | Tensor | Covariance matrix of the Normal distribution. |
log_likelihood | Tensor | Log-likelihood of the data given the parameters. |
aux | Any | Auxiliary information from the log_likelihood call. |
Source code in posteriors/ekf/dense_fisher.py
posteriors.ekf.dense_fisher.init(params, init_cov=1.0)
Initialise Multivariate Normal distribution over parameters.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
params | TensorTree | Initial mean of the Normal distribution. | required |
init_cov | Tensor \| float | Initial covariance matrix of the Multivariate Normal distribution. If it is a float, it is interpreted as an identity matrix scaled by that float. | 1.0 |
Returns:

Type | Description |
---|---|
EKFDenseState | Initial EKFDenseState. |
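A brief sketch of calling init directly; the parameter tree is illustrative, and the shape comment assumes the covariance is stored densely over the flattened parameters as described above:

```python
import torch
import posteriors

params = {"weights": torch.zeros(3), "bias": torch.zeros(1)}

# Scalar init_cov: identity over the 4 flattened parameters, scaled by 0.1.
state = posteriors.ekf.dense_fisher.init(params, init_cov=0.1)
# state.params is the mean; state.cov is expected to be a 4 x 4 dense matrix.
```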
Source code in posteriors/ekf/dense_fisher.py
posteriors.ekf.dense_fisher.update(state, batch, log_likelihood, lr, transition_cov=0.0, per_sample=False, inplace=False)
Applies an extended Kalman Filter update to the Multivariate Normal distribution.
The approximate Bayesian update is based on the linearization
$$ \log p(θ | y) ≈ \log p(θ) + ε g(μ)^T (θ - μ) + \frac{ε}{2} (θ - μ)^T F(μ) (θ - μ) $$
where \(μ\) is the mean of the prior distribution, \(ε\) is the learning rate (or equivalently the likelihood inverse temperature), \(g(μ)\) is the gradient of the log-likelihood at \(μ\), and \(F(μ)\) is the empirical Fisher information matrix at \(μ\) for data \(y\).
Parameters:

Name | Type | Description | Default |
---|---|---|---|
state | EKFDenseState | Current state. | required |
batch | Any | Input data to log_likelihood. | required |
log_likelihood | LogProbFn | Function that takes parameters and an input batch and returns the log-likelihood value as well as auxiliary information, e.g. from the model call. | required |
lr | float | Inverse temperature of the update, which behaves like a learning rate. | required |
transition_cov | Tensor \| float | Covariance of the transition noise, used to additively inflate the covariance before the update. | 0.0 |
per_sample | bool | If True, log_likelihood is assumed to return a vector of log-likelihoods, one for each sample in the batch. If False, log_likelihood is assumed to return a scalar log-likelihood for the whole batch; in that case torch.func.vmap is called internally, which is typically slower than writing log_likelihood to be per-sample directly. | False |
inplace | bool | Whether to update the state parameters in-place. | False |
Returns:

Type | Description |
---|---|
EKFDenseState | Updated EKFDenseState. |
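A sketch of the functional update with a per-sample log-likelihood (the toy model and data are illustrative):

```python
import torch
import posteriors

def log_likelihood(params, batch):
    # Returns one log-likelihood per sample in the batch, plus auxiliary output.
    x, y = batch
    pred = x @ params["weights"]
    return -0.5 * (pred - y) ** 2, pred

params = {"weights": torch.zeros(3)}
state = posteriors.ekf.dense_fisher.init(params)

batch = (torch.randn(8, 3), torch.randn(8))
state = posteriors.ekf.dense_fisher.update(
    state, batch, log_likelihood, lr=1e-1, per_sample=True
)
```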
Source code in posteriors/ekf/dense_fisher.py
posteriors.ekf.dense_fisher.sample(state, sample_shape=torch.Size([]))
Sample(s) from the Multivariate Normal distribution over parameters.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
state | EKFDenseState | State encoding mean and covariance. | required |
sample_shape | Size | Shape of the desired samples. | Size([]) |
Returns:

Type | Description |
---|---|
TensorTree | Sample(s) from the Multivariate Normal distribution. |
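A short sketch of drawing samples (the shapes in the comments are illustrative):

```python
import torch
import posteriors

params = {"weights": torch.zeros(3)}
state = posteriors.ekf.dense_fisher.init(params, init_cov=1.0)

# Draw 100 samples from the Gaussian over parameters.
samples = posteriors.ekf.dense_fisher.sample(state, sample_shape=torch.Size([100]))
# samples mirrors the params TensorTree with a leading sample dimension,
# e.g. samples["weights"] has shape (100, 3).
```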
Source code in posteriors/ekf/dense_fisher.py