
EKF Dense Fisher

posteriors.ekf.dense_fisher.build(log_likelihood, lr, transition_cov=0.0, per_sample=False, init_cov=1.0)

Builds a transform to implement an extended Kalman Filter update.

EKF applies an online update to a Gaussian posterior over the parameters.

The approximate Bayesian update is based on the linearization

$$
\log p(θ \mid y) ≈ \log p(θ) + ε\, g(μ)^\top (θ - μ) + \frac12 ε\, (θ - μ)^\top F(μ) (θ - μ)
$$

where \(μ\) is the mean of the prior distribution, \(ε\) is the learning rate (or equivalently the likelihood inverse temperature), \(g(μ)\) is the gradient of the log likelihood at \(μ\), and \(F(μ)\) is the empirical Fisher information matrix at \(μ\) for data \(y\).

For more information on extended Kalman filtering, as well as an equivalence to (online) natural gradient descent, see [Ollivier, 2019](https://arxiv.org/abs/1703.00209).
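Completing the square against the Gaussian prior \(N(μ, Σ)\) gives the closed-form update that `update` below implements, with \(Q\) the transition covariance:

$$
Σ_\text{pred} = Σ + Q, \qquad
Σ_\text{new} = \big(Σ_\text{pred}^{-1} - ε F(μ)\big)^{-1}, \qquad
μ_\text{new} = μ + ε\, Σ_\text{new}\, g(μ).
$$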

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `log_likelihood` | `LogProbFn` | Function that takes parameters and input batch and returns the log-likelihood value as well as auxiliary information, e.g. from the model call. | required |
| `lr` | `float` | Inverse temperature of the update, which behaves like a learning rate. | required |
| `transition_cov` | `Tensor \| float` | Covariance of the transition noise, used to additively inflate the covariance before the update. | `0.0` |
| `per_sample` | `bool` | If True, `log_likelihood` is assumed to return a vector of log likelihoods, one per sample in the batch. If False, `log_likelihood` is assumed to return a scalar log likelihood for the whole batch; in that case `torch.func.vmap` is applied internally, which is typically slower than writing `log_likelihood` to be per-sample directly. | `False` |
| `init_cov` | `Tensor \| float` | Initial covariance of the Normal distribution. Can be a `torch.Tensor` or a scalar. | `1.0` |

Returns:

| Type | Description |
| --- | --- |
| `Transform` | EKF transform instance. |

Source code in posteriors/ekf/dense_fisher.py
def build(
    log_likelihood: LogProbFn,
    lr: float,
    transition_cov: torch.Tensor | float = 0.0,
    per_sample: bool = False,
    init_cov: torch.Tensor | float = 1.0,
) -> Transform:
    """Builds a transform to implement an extended Kalman Filter update.

    EKF applies an online update to a Gaussian posterior over the parameters.

    The approximate Bayesian update is based on the linearization
    $$
    \\log p(θ | y) ≈ \\log p(θ) +  ε g(μ)ᵀ(θ - μ) +  \\frac12 ε (θ - μ)^T F(μ) (θ - μ)
    $$
    where $μ$ is the mean of the prior distribution, $ε$ is the learning rate
    (or equivalently the likelihood inverse temperature),
    $g(μ)$ is the gradient of the log likelihood at μ and $F(μ)$ is the
    empirical Fisher information matrix at $μ$ for data $y$.

    For more information on extended Kalman filtering as well as an equivalence
    to (online) natural gradient descent see [Ollivier, 2019](https://arxiv.org/abs/1703.00209).

    Args:
        log_likelihood: Function that takes parameters and input batch and
            returns the log-likelihood value as well as auxiliary information,
            e.g. from the model call.
        lr: Inverse temperature of the update, which behaves like a learning rate.
        transition_cov: Covariance of the transition noise, to additively
            inflate the covariance before the update.
        per_sample: If True, then log_likelihood is assumed to return a vector of
            log likelihoods for each sample in the batch. If False, then log_likelihood
            is assumed to return a scalar log likelihood for the whole batch, in this
            case torch.func.vmap will be called, this is typically slower than
            directly writing log_likelihood to be per sample.
        init_cov: Initial covariance of the Normal distribution. Can be torch.Tensor or scalar.

    Returns:
        EKF transform instance.
    """
    init_fn = partial(init, init_cov=init_cov)
    update_fn = partial(
        update,
        log_likelihood=log_likelihood,
        lr=lr,
        transition_cov=transition_cov,
        per_sample=per_sample,
    )
    return Transform(init_fn, update_fn)
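A minimal usage sketch, assuming the usual posteriors pattern of calling the returned transform's init and update; the regression-style log_likelihood, the parameter names and the data below are purely illustrative:

```python
import torch
import posteriors

# Illustrative log-likelihood: Gaussian observation model for a toy linear regression.
# Returns a scalar log likelihood for the whole batch plus auxiliary info (the predictions).
def log_likelihood(params, batch):
    x, y = batch
    preds = x @ params["weights"] + params["bias"]
    return torch.distributions.Normal(preds, 1.0).log_prob(y).sum(), preds

params = {"weights": torch.zeros(2), "bias": torch.zeros(1)}

transform = posteriors.ekf.dense_fisher.build(log_likelihood, lr=1e-1)
state = transform.init(params)

batch = (torch.randn(8, 2), torch.randn(8))
state = transform.update(state, batch)  # one online EKF step on this batch
```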

posteriors.ekf.dense_fisher.EKFDenseState

Bases: NamedTuple

State encoding a Normal distribution over parameters.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `params` | `TensorTree` | Mean of the Normal distribution. |
| `cov` | `Tensor` | Covariance matrix of the Normal distribution. |
| `log_likelihood` | `Tensor` | Log likelihood of the data given the parameters. |
| `aux` | `Any` | Auxiliary information from the `log_likelihood` call. |

Source code in posteriors/ekf/dense_fisher.py
class EKFDenseState(NamedTuple):
    """State encoding a Normal distribution over parameters.

    Attributes:
        params: Mean of the Normal distribution.
        cov: Covariance matrix of the
            Normal distribution.
        log_likelihood: Log likelihood of the data given the parameters.
        aux: Auxiliary information from the log_likelihood call.
    """

    params: TensorTree
    cov: torch.Tensor
    log_likelihood: torch.Tensor = torch.tensor([])
    aux: Any = None
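Since EKFDenseState is a NamedTuple, its fields can be read directly after an update; a small sketch continuing from a state like the one above:

```python
posterior_mean = state.params  # posterior mean, same TensorTree structure as the parameters
posterior_cov = state.cov      # dense covariance matrix over all flattened parameters

print(state.log_likelihood)    # log likelihood recorded at the most recent update
print(state.aux)               # auxiliary output of log_likelihood, e.g. model predictions
```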

posteriors.ekf.dense_fisher.init(params, init_cov=1.0)

Initialise Multivariate Normal distribution over parameters.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `params` | `TensorTree` | Initial mean of the Normal distribution. | required |
| `init_cov` | `Tensor \| float` | Initial covariance matrix of the Multivariate Normal distribution. If it is a float, it is taken as an identity matrix scaled by that float. | `1.0` |

Returns:

| Type | Description |
| --- | --- |
| `EKFDenseState` | Initial EKFDenseState. |

Source code in posteriors/ekf/dense_fisher.py
def init(
    params: TensorTree,
    init_cov: torch.Tensor | float = 1.0,
) -> EKFDenseState:
    """Initialise Multivariate Normal distribution over parameters.

    Args:
        params: Initial mean of the Normal distribution.
        init_cov: Initial covariance matrix of the Multivariate Normal distribution.
            If it is a float, it is defined as an identity matrix scaled by that float.

    Returns:
        Initial EKFDenseState.
    """
    if is_scalar(init_cov):
        num_params = tree_size(params)
        init_cov = init_cov * torch.eye(num_params, requires_grad=False)

    return EKFDenseState(params, init_cov)
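For example, a scalar init_cov is expanded to an identity matrix scaled by that value, with one row and column per flattened parameter; a small sketch with hypothetical parameter names:

```python
import torch
from posteriors.ekf import dense_fisher

params = {"weights": torch.zeros(3), "bias": torch.zeros(1)}  # 4 parameters in total
state = dense_fisher.init(params, init_cov=0.1)

assert state.cov.shape == (4, 4)
assert torch.allclose(state.cov, 0.1 * torch.eye(4))  # identity scaled by init_cov
```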

posteriors.ekf.dense_fisher.update(state, batch, log_likelihood, lr, transition_cov=0.0, per_sample=False, inplace=False)

Applies an extended Kalman Filter update to the Multivariate Normal distribution. The approximate Bayesian update is based on the linearization

$$
\log p(θ \mid y) ≈ \log p(θ) + ε\, g(μ)^\top (θ - μ) + \frac12 ε\, (θ - μ)^\top F(μ) (θ - μ)
$$

where \(μ\) is the mean of the prior distribution, \(ε\) is the learning rate (or equivalently the likelihood inverse temperature), \(g(μ)\) is the gradient of the log likelihood at \(μ\), and \(F(μ)\) is the empirical Fisher information matrix at \(μ\) for data \(y\).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `state` | `EKFDenseState` | Current state. | required |
| `batch` | `Any` | Input data to `log_likelihood`. | required |
| `log_likelihood` | `LogProbFn` | Function that takes parameters and input batch and returns the log-likelihood value as well as auxiliary information, e.g. from the model call. | required |
| `lr` | `float` | Inverse temperature of the update, which behaves like a learning rate. | required |
| `transition_cov` | `Tensor \| float` | Covariance of the transition noise, used to additively inflate the covariance before the update. | `0.0` |
| `per_sample` | `bool` | If True, `log_likelihood` is assumed to return a vector of log likelihoods, one per sample in the batch. If False, `log_likelihood` is assumed to return a scalar log likelihood for the whole batch; in that case `torch.func.vmap` is applied internally, which is typically slower than writing `log_likelihood` to be per-sample directly. | `False` |
| `inplace` | `bool` | Whether to update the state parameters in-place. | `False` |

Returns:

| Type | Description |
| --- | --- |
| `EKFDenseState` | Updated EKFDenseState. |

Source code in posteriors/ekf/dense_fisher.py
def update(
    state: EKFDenseState,
    batch: Any,
    log_likelihood: LogProbFn,
    lr: float,
    transition_cov: torch.Tensor | float = 0.0,
    per_sample: bool = False,
    inplace: bool = False,
) -> EKFDenseState:
    """Applies an extended Kalman Filter update to the Multivariate Normal distribution.
    The approximate Bayesian update is based on the linearization
    $$
    \\log p(θ | y) ≈ \\log p(θ) +  ε g(μ)ᵀ(θ - μ) +  \\frac12 ε (θ - μ)^T F(μ) (θ - μ)
    $$
    where $μ$ is the mean of the prior distribution, $ε$ is the learning rate
    (or equivalently the likelihood inverse temperature),
    $g(μ)$ is the gradient of the log likelihood at μ and $F(μ)$ is the
    empirical Fisher information matrix at $μ$ for data $y$.

    Args:
        state: Current state.
        batch: Input data to log_likelihood.
        log_likelihood: Function that takes parameters and input batch and
            returns the log-likelihood value as well as auxiliary information,
            e.g. from the model call.
        lr: Inverse temperature of the update, which behaves like a learning rate.
        transition_cov: Covariance of the transition noise, to additively
            inflate the covariance before the update.
        per_sample: If True, then log_likelihood is assumed to return a vector of
            log likelihoods for each sample in the batch. If False, then log_likelihood
            is assumed to return a scalar log likelihood for the whole batch, in this
            case torch.func.vmap will be called, this is typically slower than
            directly writing log_likelihood to be per sample.
        inplace: Whether to update the state parameters in-place.

    Returns:
        Updated EKFDenseState.
    """
    if not per_sample:
        log_likelihood = per_samplify(log_likelihood)

    with torch.no_grad(), CatchAuxError():

        def log_likelihood_reduced(params, batch):
            per_samp_log_lik, internal_aux = log_likelihood(params, batch)
            return per_samp_log_lik.mean(), internal_aux

        grad, (log_liks, aux) = grad_and_value(log_likelihood_reduced, has_aux=True)(
            state.params, batch
        )
        fisher, _ = empirical_fisher(
            lambda p: log_likelihood(p, batch), has_aux=True, normalize=True
        )(state.params)

        predict_cov = state.cov + transition_cov
        predict_cov_inv = torch.cholesky_inverse(torch.linalg.cholesky(predict_cov))
        update_cov_inv = predict_cov_inv - lr * fisher
        update_cov = torch.cholesky_inverse(torch.linalg.cholesky(update_cov_inv))

        mu_raveled, mu_unravel_f = tree_ravel(state.params)
        update_mean = mu_raveled + lr * update_cov @ tree_ravel(grad)[0]
        update_mean = mu_unravel_f(update_mean)

    if inplace:
        tree_insert_(state.params, update_mean)
        tree_insert_(state.cov, update_cov)
        tree_insert_(state.log_likelihood, log_liks.mean().detach())
        return state._replace(aux=aux)

    return EKFDenseState(update_mean, update_cov, log_liks.mean().detach(), aux)
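In an online setting, update is typically applied once per incoming batch, so the posterior from one step becomes the prior for the next. A minimal sketch, where log_likelihood and params are as in the earlier example and dataloader is a placeholder iterable of batches:

```python
state = dense_fisher.init(params)

for batch in dataloader:
    state = dense_fisher.update(
        state,
        batch,
        log_likelihood=log_likelihood,
        lr=1e-1,
        transition_cov=1e-3,  # additively inflate the covariance before each update
    )
```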

posteriors.ekf.dense_fisher.sample(state, sample_shape=torch.Size([]))

Draws sample(s) from the Multivariate Normal distribution over parameters.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `state` | `EKFDenseState` | State encoding mean and covariance. | required |
| `sample_shape` | `Size` | Shape of the desired samples. | `Size([])` |

Returns:

| Type | Description |
| --- | --- |
| `TensorTree` | Sample(s) from the Multivariate Normal distribution. |

Source code in posteriors/ekf/dense_fisher.py
def sample(
    state: EKFDenseState, sample_shape: torch.Size = torch.Size([])
) -> TensorTree:
    """Single sample from Multivariate Normal distribution over parameters.

    Args:
        state: State encoding mean and covariance.
        sample_shape: Shape of the desired samples.

    Returns:
        Sample(s) from Multivariate Normal distribution.
    """
    mean_flat, unravel_func = tree_ravel(state.params)

    samples = torch.distributions.MultivariateNormal(
        loc=mean_flat,
        covariance_matrix=state.cov,
        validate_args=False,
    ).sample(sample_shape)

    samples = torch.vmap(unravel_func)(samples)
    return samples
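Samples keep the same TensorTree structure as the parameters, with a leading sample dimension; a short sketch continuing from the states above:

```python
# Draw 100 samples from the Gaussian posterior encoded by the state.
samples = dense_fisher.sample(state, sample_shape=torch.Size([100]))

# Each leaf gains a leading dimension of size 100, e.g. for the parameters in the build
# example above: samples["weights"].shape == (100, 2) and samples["bias"].shape == (100, 1).
```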