EKF Diagonal Fisher

posteriors.ekf.diag_fisher.build(log_likelihood, lr, transition_sd=0.0, per_sample=False, init_sds=1.0)

Builds a transform to implement an extended Kalman Filter update.

EKF applies an online update to a (diagonal) Gaussian posterior over the parameters.

The approximate Bayesian update is based on the linearization $$ \log p(θ | y) ≈ \log p(θ) + ε g(μ)^T(θ - μ) - \frac12 ε (θ - μ)^T F_d(μ) (θ - μ) $$ where \(μ\) is the mean of the prior distribution, \(ε\) is the learning rate (or equivalently the likelihood inverse temperature), \(g(μ)\) is the gradient of the log likelihood at \(μ\) and \(F_d(μ)\) is the diagonal empirical Fisher information matrix at \(μ\) for data \(y\). Completing the square regains a diagonal Normal distribution over the parameters.

For more information on extended Kalman filtering as well as an equivalence to (online) natural gradient descent see [Ollivier, 2019](https://arxiv.org/abs/1703.00209).
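
Concretely, completing the square yields the elementwise update that can be read off the `update` implementation below. Writing \(σ\) for the prior standard deviation, inflated to \(σ_{\text{pred}}^2 = σ^2 + τ^2\) by the transition noise standard deviation \(τ\) (the `transition_sd` argument), the new state is

$$
σ_{\text{new}}^{-2} = σ_{\text{pred}}^{-2} + ε F_d(μ), \qquad μ_{\text{new}} = μ + ε σ_{\text{new}}^{2} g(μ),
$$

applied independently to each diagonal element.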

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `log_likelihood` | `LogProbFn` | Function that takes parameters and an input batch and returns the log-likelihood value as well as auxiliary information, e.g. from the model call. | required |
| `lr` | `float` | Inverse temperature of the update, which behaves like a learning rate. | required |
| `transition_sd` | `float` | Standard deviation of the transition noise, used to additively inflate the diagonal covariance before the update. | `0.0` |
| `per_sample` | `bool` | If True, `log_likelihood` is assumed to return a vector of log-likelihoods, one per sample in the batch. If False, it is assumed to return a scalar log-likelihood for the whole batch; in that case `torch.func.vmap` is applied, which is typically slower than writing `log_likelihood` to be per-sample directly. | `False` |
| `init_sds` | `TensorTree \| float` | Initial square-root diagonal of the covariance matrix of the Normal distribution. Can be a tree like `params` or a scalar. | `1.0` |

Returns:

| Type | Description |
| --- | --- |
| `Transform` | Diagonal EKF transform instance. |
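
For concreteness, here is a minimal, hypothetical sketch of a `log_likelihood` function compatible with `per_sample=True`. The linear model, the `"weights"`/`"bias"` parameter tree and the unit Gaussian noise scale are illustrative assumptions, not part of the library:

```python
import torch


def log_likelihood(params, batch):
    # Hypothetical per-sample Gaussian likelihood for a linear model.
    # params: dict TensorTree with "weights" (d,) and "bias" (1,); batch: (inputs, targets).
    inputs, targets = batch
    preds = inputs @ params["weights"] + params["bias"]
    log_liks = torch.distributions.Normal(preds, 1.0).log_prob(targets)  # shape (batch_size,)
    return log_liks, preds  # second output is passed through as aux
```

With `per_sample=False`, the same function would instead return a scalar (e.g. `log_liks.mean()`) and `torch.func.vmap` would be used to recover per-sample values.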

Source code in posteriors/ekf/diag_fisher.py
def build(
    log_likelihood: LogProbFn,
    lr: float,
    transition_sd: float = 0.0,
    per_sample: bool = False,
    init_sds: TensorTree | float = 1.0,
) -> Transform:
    """Builds a transform to implement an extended Kalman Filter update.

    EKF applies an online update to a (diagonal) Gaussian posterior over the parameters.

    The approximate Bayesian update is based on the linearization
    $$
    \\log p(θ | y) ≈ \\log p(θ) +  ε g(μ)^T(θ - μ) -  \\frac12 ε (θ - μ)^T F_d(μ) (θ - μ)
    $$
    where $μ$ is the mean of the prior distribution, $ε$ is the learning rate
    (or equivalently the likelihood inverse temperature),
    $g(μ)$ is the gradient of the log likelihood at μ and $F_d(μ)$ is the diagonal
    empirical Fisher information matrix at $μ$ for data $y$. Completing the square
    regains a diagonal Normal distribution over the parameters.

    For more information on extended Kalman filtering as well as an equivalence
    to (online) natural gradient descent see [Ollivier, 2019](https://arxiv.org/abs/1703.00209).

    Args:
        log_likelihood: Function that takes parameters and input batch and
            returns the log-likelihood value as well as auxiliary information,
            e.g. from the model call.
        lr: Inverse temperature of the update, which behaves like a learning rate.
        transition_sd: Standard deviation of the transition noise, to additively
            inflate the diagonal covariance before the update.
        per_sample: If True, then log_likelihood is assumed to return a vector of
            log likelihoods for each sample in the batch. If False, then log_likelihood
            is assumed to return a scalar log likelihood for the whole batch, in this
            case torch.func.vmap will be called, this is typically slower than
            directly writing log_likelihood to be per sample.
        init_sds: Initial square-root diagonal of the covariance matrix
            of the Normal distribution. Can be tree like params or scalar.

    Returns:
        Diagonal EKF transform instance.
    """
    init_fn = partial(init, init_sds=init_sds)
    update_fn = partial(
        update,
        log_likelihood=log_likelihood,
        lr=lr,
        transition_sd=transition_sd,
        per_sample=per_sample,
    )
    return Transform(init_fn, update_fn)
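
Putting it together, a minimal usage sketch. The parameter tree and data stream are hypothetical, `log_likelihood` is the per-sample function sketched above, and the returned `Transform` is assumed to expose the `init` and `update` functions documented below in the usual posteriors fashion:

```python
import torch
from posteriors.ekf import diag_fisher

params = {"weights": torch.randn(5), "bias": torch.zeros(1)}  # hypothetical parameter tree
batches = [(torch.randn(32, 5), torch.randn(32)) for _ in range(10)]  # synthetic data stream

transform = diag_fisher.build(
    log_likelihood, lr=1e-1, transition_sd=1e-3, per_sample=True
)

state = transform.init(params)
for batch in batches:
    state = transform.update(state, batch)

# state.params and state.sd_diag now encode the diagonal Gaussian posterior.
```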

posteriors.ekf.diag_fisher.EKFDiagState

Bases: NamedTuple

State encoding a diagonal Normal distribution over parameters.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `params` | `TensorTree` | Mean of the Normal distribution. |
| `sd_diag` | `TensorTree` | Square-root diagonal of the covariance matrix of the Normal distribution. |
| `log_likelihood` | `Tensor` | Log likelihood of the data given the parameters. |
| `aux` | `Any` | Auxiliary information from the `log_likelihood` call. |

Source code in posteriors/ekf/diag_fisher.py
class EKFDiagState(NamedTuple):
    """State encoding a diagonal Normal distribution over parameters.

    Attributes:
        params: Mean of the Normal distribution.
        sd_diag: Square-root diagonal of the covariance matrix of the
            Normal distribution.
        log_likelihood: Log likelihood of the data given the parameters.
        aux: Auxiliary information from the log_likelihood call.
    """

    params: TensorTree
    sd_diag: TensorTree
    log_likelihood: torch.Tensor = torch.tensor([])
    aux: Any = None

posteriors.ekf.diag_fisher.init(params, init_sds=1.0)

Initialise diagonal Normal distribution over parameters.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `params` | `TensorTree` | Initial mean of the Normal distribution. | required |
| `init_sds` | `TensorTree \| float` | Initial square-root diagonal of the covariance matrix of the Normal distribution. Can be a tree like `params` or a scalar. | `1.0` |

Returns:

| Type | Description |
| --- | --- |
| `EKFDiagState` | Initial EKFDiagState. |

Source code in posteriors/ekf/diag_fisher.py
def init(
    params: TensorTree,
    init_sds: TensorTree | float = 1.0,
) -> EKFDiagState:
    """Initialise diagonal Normal distribution over parameters.

    Args:
        params: Initial mean of the Normal distribution.
        init_sds: Initial square-root diagonal of the covariance matrix
            of the Normal distribution. Can be tree like params or scalar.

    Returns:
        Initial EKFDiagState.
    """
    if is_scalar(init_sds):
        init_sds = tree_map(
            lambda x: torch.full_like(x, init_sds, requires_grad=x.requires_grad),
            params,
        )

    return EKFDiagState(params, init_sds)
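
For illustration, `init_sds` can be a scalar applied to every parameter or a tree with the same structure as `params`. The parameter tree below is a hypothetical example:

```python
import torch
from posteriors.ekf.diag_fisher import init

params = {"weights": torch.randn(5), "bias": torch.zeros(1)}

# Scalar init_sds: every parameter starts with standard deviation 0.1.
state = init(params, init_sds=0.1)

# Tree-valued init_sds: per-parameter standard deviations matching the structure of params.
state = init(
    params,
    init_sds={"weights": torch.full((5,), 0.05), "bias": torch.ones(1)},
)
```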

posteriors.ekf.diag_fisher.update(state, batch, log_likelihood, lr, transition_sd=0.0, per_sample=False, inplace=False)

Applies an extended Kalman Filter update to the diagonal Normal distribution. The approximate Bayesian update is based on the linearization $$ \log p(θ | y) ≈ \log p(θ) + ε g(μ)^T(θ - μ) - \frac12 ε (θ - μ)^T F_d(μ) (θ - μ) $$ where \(μ\) is the mean of the prior distribution, \(ε\) is the learning rate (or equivalently the likelihood inverse temperature), \(g(μ)\) is the gradient of the log likelihood at \(μ\) and \(F_d(μ)\) is the diagonal empirical Fisher information matrix at \(μ\) for data \(y\). Completing the square regains a diagonal Normal distribution over the parameters.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `state` | `EKFDiagState` | Current state. | required |
| `batch` | `Any` | Input data to `log_likelihood`. | required |
| `log_likelihood` | `LogProbFn` | Function that takes parameters and an input batch and returns the log-likelihood value as well as auxiliary information, e.g. from the model call. | required |
| `lr` | `float` | Inverse temperature of the update, which behaves like a learning rate. | required |
| `transition_sd` | `float` | Standard deviation of the transition noise, used to additively inflate the diagonal covariance before the update. | `0.0` |
| `per_sample` | `bool` | If True, `log_likelihood` is assumed to return a vector of log-likelihoods, one per sample in the batch. If False, it is assumed to return a scalar log-likelihood for the whole batch; in that case `torch.func.vmap` is applied, which is typically slower than writing `log_likelihood` to be per-sample directly. | `False` |
| `inplace` | `bool` | Whether to update the state parameters in-place. | `False` |

Returns:

| Type | Description |
| --- | --- |
| `EKFDiagState` | Updated EKFDiagState. |

Source code in posteriors/ekf/diag_fisher.py
def update(
    state: EKFDiagState,
    batch: Any,
    log_likelihood: LogProbFn,
    lr: float,
    transition_sd: float = 0.0,
    per_sample: bool = False,
    inplace: bool = False,
) -> EKFDiagState:
    """Applies an extended Kalman Filter update to the diagonal Normal distribution.
    The approximate Bayesian update is based on the linearization
    $$
    \\log p(θ | y) ≈ \\log p(θ) +  ε g(μ)^T(θ - μ) -  \\frac12 ε (θ - μ)^T F_d(μ) (θ - μ)
    $$
    where $μ$ is the mean of the prior distribution, $ε$ is the learning rate
    (or equivalently the likelihood inverse temperature),
    $g(μ)$ is the gradient of the log likelihood at μ and $F_d(μ)$ is the diagonal
    empirical Fisher information matrix at $μ$ for data $y$. Completing the square
    regains a diagonal Normal distribution over the parameters.

    Args:
        state: Current state.
        batch: Input data to log_likelihood.
        log_likelihood: Function that takes parameters and input batch and
            returns the log-likelihood value as well as auxiliary information,
            e.g. from the model call.
        lr: Inverse temperature of the update, which behaves like a learning rate.
        transition_sd: Standard deviation of the transition noise, to additively
            inflate the diagonal covariance before the update.
        per_sample: If True, then log_likelihood is assumed to return a vector of
            log likelihoods for each sample in the batch. If False, then log_likelihood
            is assumed to return a scalar log likelihood for the whole batch, in this
            case torch.func.vmap will be called, this is typically slower than
            directly writing log_likelihood to be per sample.
        inplace: Whether to update the state parameters in-place.

    Returns:
        Updated EKFDiagState.
    """

    if not per_sample:
        log_likelihood = per_samplify(log_likelihood)

    predict_sd_diag = flexi_tree_map(
        lambda x: (x**2 + transition_sd**2) ** 0.5, state.sd_diag, inplace=inplace
    )
    with torch.no_grad(), CatchAuxError():
        log_liks, aux = log_likelihood(state.params, batch)
        jac, _ = jacrev(log_likelihood, has_aux=True)(state.params, batch)
        grad = tree_map(lambda x: x.mean(0), jac)
        diag_lik_hessian_approx = tree_map(lambda x: -(x**2).mean(0), jac)

    update_sd_diag = flexi_tree_map(
        lambda sig, h: (sig**-2 - lr * h) ** -0.5,
        predict_sd_diag,
        diag_lik_hessian_approx,
        inplace=inplace,
    )
    update_mean = flexi_tree_map(
        lambda mu, sig, g: mu + sig**2 * lr * g,
        state.params,
        update_sd_diag,
        grad,
        inplace=inplace,
    )

    if inplace:
        tree_insert_(state.log_likelihood, log_liks.mean().detach())
        return state._replace(aux=aux)

    return EKFDiagState(update_mean, update_sd_diag, log_liks.mean().detach(), aux)
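
The functional API can also be called directly. A minimal sketch of a single update step, reusing the hypothetical per-sample `log_likelihood` from earlier with a synthetic batch:

```python
import torch
from posteriors.ekf.diag_fisher import init, update

params = {"weights": torch.randn(5), "bias": torch.zeros(1)}
state = init(params, init_sds=1.0)

batch = (torch.randn(32, 5), torch.randn(32))
new_state = update(state, batch, log_likelihood, lr=1e-1, transition_sd=1e-3, per_sample=True)

print(state.sd_diag["weights"])      # prior standard deviations
print(new_state.sd_diag["weights"])  # typically smaller after conditioning on the batch
print(new_state.log_likelihood)      # mean per-sample log-likelihood of this batch
```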

posteriors.ekf.diag_fisher.sample(state, sample_shape=torch.Size([]))

Sample from the diagonal Normal distribution over parameters (a single sample by default, or a batch of samples according to sample_shape).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `state` | `EKFDiagState` | State encoding mean and standard deviations. | required |
| `sample_shape` | `Size` | Shape of the desired samples. | `Size([])` |

Returns:

| Type | Description |
| --- | --- |
| `TensorTree` | Sample(s) from Normal distribution. |

Source code in posteriors/ekf/diag_fisher.py
def sample(
    state: EKFDiagState, sample_shape: torch.Size = torch.Size([])
) -> TensorTree:
    """Single sample from diagonal Normal distribution over parameters.

    Args:
        state: State encoding mean and standard deviations.
        sample_shape: Shape of the desired samples.

    Returns:
        Sample(s) from Normal distribution.
    """
    return diag_normal_sample(state.params, state.sd_diag, sample_shape=sample_shape)
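
As an illustration, posterior samples can be used for Monte Carlo prediction. The `state`, linear model shapes and test inputs below are the hypothetical ones from the earlier sketches:

```python
import torch
from posteriors.ekf.diag_fisher import sample

# Draw 100 parameter samples from the diagonal Normal encoded by `state`.
param_samples = sample(state, sample_shape=torch.Size([100]))

# Each leaf gains a leading sample dimension, e.g. (100, 5) for "weights".
test_inputs = torch.randn(8, 5)
preds = torch.einsum("sd,nd->sn", param_samples["weights"], test_inputs) + param_samples["bias"]
predictive_mean = preds.mean(0)  # Monte Carlo average over the posterior samples
```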