
VI Diag

posteriors.vi.diag.build(log_posterior, optimizer, temperature=1.0, n_samples=1, stl=True, init_log_sds=0.0)

Builds a transform for variational inference with a diagonal Normal distribution over parameters.

Find \(\mu\) and diagonal \(\Sigma\) that minimize \(\text{KL}(N(θ| \mu, \Sigma) || p_T(θ))\) where \(p_T(θ) \propto \exp( \log p(θ) / T)\) with temperature \(T\).

The log posterior and temperature are recommended to be constructed in tandem to ensure robust scaling for a large amount of data.

For more information on variational inference see [Blei et al, 2017](https://arxiv.org/abs/1601.00670).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| log_posterior | Callable[[TensorTree, Any], float] | Function that takes parameters and input batch and returns the log posterior (which can be unnormalised). | required |
| optimizer | GradientTransformation | TorchOpt functional optimizer for updating the variational parameters. Make sure to use lower case like torchopt.adam(). | required |
| temperature | float \| Schedule | Temperature to rescale (divide) log_posterior. Scalar or schedule (callable taking step index, returning scalar). | 1.0 |
| n_samples | int | Number of samples to use for the Monte Carlo estimate. | 1 |
| stl | bool | Whether to use the stick-the-landing estimator from [Roeder et al](https://arxiv.org/abs/1703.09194). | True |
| init_log_sds | TensorTree \| float | Initial log of the square-root diagonal of the covariance matrix of the variational distribution. Can be a tree matching params or scalar. | 0.0 |

Returns:

| Type | Description |
| --- | --- |
| Transform | Diagonal VI transform instance. |
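
As a rough end-to-end sketch of how the built transform might be used, assuming the returned Transform exposes init and update as elsewhere in the posteriors API; the log_posterior, data and hyperparameters below are illustrative placeholders, not part of this API:

```
import torch
import torchopt
import posteriors

# Illustrative unnormalised log posterior: Gaussian likelihood plus a standard
# Normal prior over a single parameter tensor. Any function with the signature
# (params, batch) -> (scalar, aux) should work.
def log_posterior(params, batch):
    x, y = batch
    pred = x @ params
    log_lik = torch.distributions.Normal(pred, 1.0).log_prob(y).sum()
    log_prior = torch.distributions.Normal(0.0, 1.0).log_prob(params).sum()
    return log_lik + log_prior, pred  # second element is auxiliary info

params = torch.zeros(3)
transform = posteriors.vi.diag.build(log_posterior, torchopt.adam(lr=1e-2))
state = transform.init(params)

batch = (torch.randn(10, 3), torch.randn(10))
state, aux = transform.update(state, batch)  # one step of NELBO minimization
```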

Source code in posteriors/vi/diag.py
def build(
    log_posterior: Callable[[TensorTree, Any], float],
    optimizer: torchopt.base.GradientTransformation,
    temperature: float | Schedule = 1.0,
    n_samples: int = 1,
    stl: bool = True,
    init_log_sds: TensorTree | float = 0.0,
) -> Transform:
    """Builds a transform for variational inference with a diagonal Normal
    distribution over parameters.

    Find $\\mu$ and diagonal $\\Sigma$ that minimize $\\text{KL}(N(θ| \\mu, \\Sigma) || p_T(θ))$
    where $p_T(θ) \\propto \\exp( \\log p(θ) / T)$ with temperature $T$.

    The log posterior and temperature are recommended to be [constructed in tandem](../../log_posteriors.md)
    to ensure robust scaling for a large amount of data.

    For more information on variational inference see [Blei et al, 2017](https://arxiv.org/abs/1601.00670).

    Args:
        log_posterior: Function that takes parameters and input batch and
            returns the log posterior (which can be unnormalised).
        optimizer: TorchOpt functional optimizer for updating the variational
            parameters. Make sure to use lower case like torchopt.adam()
        temperature: Temperature to rescale (divide) log_posterior.
            Scalar or schedule (callable taking step index, returning scalar).
        n_samples: Number of samples to use for Monte Carlo estimate.
        stl: Whether to use the stick-the-landing estimator
            from [Roeder et al](https://arxiv.org/abs/1703.09194).
        init_log_sds: Initial log of the square-root diagonal of the covariance matrix
            of the variational distribution. Can be a tree matching params or scalar.

    Returns:
        Diagonal VI transform instance.
    """
    init_fn = partial(init, optimizer=optimizer, init_log_sds=init_log_sds)
    update_fn = partial(
        update,
        log_posterior=log_posterior,
        optimizer=optimizer,
        temperature=temperature,
        n_samples=n_samples,
        stl=stl,
    )
    return Transform(init_fn, update_fn)

posteriors.vi.diag.VIDiagState

Bases: TensorClass['frozen']

State encoding a diagonal Normal variational distribution over parameters.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| params | TensorTree | Mean of the variational distribution. |
| log_sd_diag | TensorTree | Log of the square-root diagonal of the covariance matrix of the variational distribution. |
| opt_state | OptState | TorchOpt state storing optimizer data for updating the variational parameters. |
| nelbo | Tensor | Negative evidence lower bound (lower is better). |
| step | Tensor | Current step count. |
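
For example, once a state has been produced by init or update, its fields can be read directly (a small illustrative sketch; state is assumed to be a VIDiagState produced by the functions below):

```
print(state.nelbo.item())   # latest NELBO estimate (lower is better)
print(state.step.item())    # number of update steps taken so far
mean = state.params         # variational mean, same structure as the original params
```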

Source code in posteriors/vi/diag.py
class VIDiagState(TensorClass["frozen"]):
    """State encoding a diagonal Normal variational distribution over parameters.

    Attributes:
        params: Mean of the variational distribution.
        log_sd_diag: Log of the square-root diagonal of the covariance matrix of the
            variational distribution.
        opt_state: TorchOpt state storing optimizer data for updating the
            variational parameters.
        nelbo: Negative evidence lower bound (lower is better).
        step: Current step count.
    """

    params: TensorTree
    log_sd_diag: TensorTree
    opt_state: torchopt.typing.OptState
    nelbo: torch.Tensor = torch.tensor([])
    step: torch.Tensor = torch.tensor(0)

posteriors.vi.diag.init(params, optimizer, init_log_sds=0.0)

Initialise diagonal Normal variational distribution over parameters.

optimizer.init will be called on flattened variational parameters so hyperparameters such as learning rate need to be pre-specified through TorchOpt's functional API:

import torchopt

optimizer = torchopt.adam(lr=1e-2)
vi_state = init(init_mean, optimizer)

It's assumed maximize=False for the optimizer, so that we minimize the NELBO.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| params | TensorTree | Initial mean of the variational distribution. | required |
| optimizer | GradientTransformation | TorchOpt functional optimizer for updating the variational parameters. Make sure to use lower case like torchopt.adam(). | required |
| init_log_sds | TensorTree \| float | Initial log of the square-root diagonal of the covariance matrix of the variational distribution. Can be a tree matching params or scalar. | 0.0 |

Returns:

| Type | Description |
| --- | --- |
| VIDiagState | Initial VIDiagState. |
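
init_log_sds can also be a TensorTree mirroring params, giving each leaf its own initial log standard deviation. A minimal sketch (the parameter shapes and the value -3.0 are illustrative assumptions):

```
import torch
import torchopt
from posteriors.vi.diag import init

params = {"weight": torch.zeros(5, 3), "bias": torch.zeros(5)}
init_log_sds = {
    "weight": torch.full((5, 3), -3.0),  # sd ≈ 0.05 for the weights
    "bias": torch.full((5,), -3.0),      # sd ≈ 0.05 for the bias
}

optimizer = torchopt.adam(lr=1e-2)
vi_state = init(params, optimizer, init_log_sds=init_log_sds)
```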

Source code in posteriors/vi/diag.py
def init(
    params: TensorTree,
    optimizer: torchopt.base.GradientTransformation,
    init_log_sds: TensorTree | float = 0.0,
) -> VIDiagState:
    """Initialise diagonal Normal variational distribution over parameters.

    optimizer.init will be called on flattened variational parameters so hyperparameters
    such as learning rate need to be pre-specified through TorchOpt's functional API:

    ```
    import torchopt

    optimizer = torchopt.adam(lr=1e-2)
    vi_state = init(init_mean, optimizer)
    ```

    It's assumed maximize=False for the optimizer, so that we minimize the NELBO.

    Args:
        params: Initial mean of the variational distribution.
        optimizer: TorchOpt functional optimizer for updating the variational
            parameters. Make sure to use lower case like torchopt.adam()
        init_log_sds: Initial log of the square-root diagonal of the covariance matrix
            of the variational distribution. Can be a tree matching params or scalar.

    Returns:
        Initial VIDiagState.
    """
    if is_scalar(init_log_sds):
        init_log_sds = tree_map(
            lambda x: torch.full_like(x, init_log_sds, requires_grad=x.requires_grad),
            params,
        )

    opt_state = optimizer.init([params, init_log_sds])
    return VIDiagState(params, init_log_sds, opt_state)

posteriors.vi.diag.update(state, batch, log_posterior, optimizer, temperature=1.0, n_samples=1, stl=True, inplace=False)

Updates the variational parameters to minimize the NELBO.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| state | VIDiagState | Current state. | required |
| batch | Any | Input data to log_posterior. | required |
| log_posterior | LogProbFn | Function that takes parameters and input batch and returns the log posterior (which can be unnormalised). | required |
| optimizer | GradientTransformation | TorchOpt functional optimizer for updating the variational parameters. Make sure to use lower case like torchopt.adam(). | required |
| temperature | float | Temperature to rescale (divide) log_posterior. Scalar or schedule (callable taking step index, returning scalar). | 1.0 |
| n_samples | int | Number of samples to use for the Monte Carlo estimate. | 1 |
| stl | bool | Whether to use the stick-the-landing estimator from [Roeder et al](https://arxiv.org/abs/1703.09194). | True |
| inplace | bool | Whether to modify state in place. | False |

Returns:

| Type | Description |
| --- | --- |
| tuple[VIDiagState, TensorTree] | Updated VIDiagState and auxiliary information. |
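
A minimal training-loop sketch calling update directly; dataloader, params and log_posterior are illustrative placeholders, not part of this API:

```
import torchopt
from posteriors.vi import diag

optimizer = torchopt.adam(lr=1e-2)
state = diag.init(params, optimizer)

for epoch in range(10):
    for batch in dataloader:
        # Each call draws n_samples from the variational distribution,
        # estimates the NELBO and takes one optimizer step on (mean, log_sd_diag).
        state, aux = diag.update(state, batch, log_posterior, optimizer)
    print(f"epoch {epoch}: nelbo = {state.nelbo.item():.4f}")
```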

Source code in posteriors/vi/diag.py
def update(
    state: VIDiagState,
    batch: Any,
    log_posterior: LogProbFn,
    optimizer: torchopt.base.GradientTransformation,
    temperature: float = 1.0,
    n_samples: int = 1,
    stl: bool = True,
    inplace: bool = False,
) -> tuple[VIDiagState, TensorTree]:
    """Updates the variational parameters to minimize the NELBO.

    Args:
        state: Current state.
        batch: Input data to log_posterior.
        log_posterior: Function that takes parameters and input batch and
            returns the log posterior (which can be unnormalised).
        optimizer: TorchOpt functional optimizer for updating the variational
            parameters. Make sure to use lower case like torchopt.adam()
        temperature: Temperature to rescale (divide) log_posterior.
            Scalar or schedule (callable taking step index, returning scalar).
        n_samples: Number of samples to use for Monte Carlo estimate.
        stl: Whether to use the stick-the-landing estimator
            from [Roeder et al](https://arxiv.org/abs/1703.09194).
        inplace: Whether to modify state in place.

    Returns:
        Updated VIDiagState and auxiliary information.
    """
    temperature = temperature(state.step) if callable(temperature) else temperature

    def nelbo_log_sd(m, lsd):
        sd_diag = tree_map(torch.exp, lsd)
        return nelbo(m, sd_diag, batch, log_posterior, temperature, n_samples, stl)

    with torch.no_grad(), CatchAuxError():
        nelbo_grads, (nelbo_val, aux) = grad_and_value(
            nelbo_log_sd, argnums=(0, 1), has_aux=True
        )(state.params, state.log_sd_diag)

    updates, opt_state = optimizer.update(
        nelbo_grads,
        state.opt_state,
        params=[state.params, state.log_sd_diag],
        inplace=inplace,
    )
    mean, log_sd_diag = torchopt.apply_updates(
        (state.params, state.log_sd_diag), updates, inplace=inplace
    )

    if inplace:
        tree_insert_(state.nelbo, nelbo_val.detach())
        tree_insert_(state.step, state.step + 1)
        return state, aux

    return VIDiagState(
        mean, log_sd_diag, opt_state, nelbo_val.detach(), state.step + 1
    ), aux

posteriors.vi.diag.nelbo(mean, sd_diag, batch, log_posterior, temperature=1.0, n_samples=1, stl=True)

Returns the negative evidence lower bound (NELBO) for a diagonal Normal variational distribution over the parameters of a model.

Monte Carlo estimate with n_samples from q.

$$
\text{NELBO} = - 𝔼_{q(θ)}[\log p(y|x, θ) + \log p(θ) - \log q(θ) \cdot T]
$$

for temperature \(T\).

log_posterior expects to take parameters and input batch and return a scalar as well as a TensorTree of any auxiliary information:

log_posterior_eval, aux = log_posterior(params, batch)

The log posterior and temperature are recommended to be constructed in tandem to ensure robust scaling for a large amount of data and variable batch size.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| mean | TensorTree | Mean of the variational distribution. | required |
| sd_diag | TensorTree | Square-root diagonal of the covariance matrix of the variational distribution. | required |
| batch | Any | Input data to log_posterior. | required |
| log_posterior | LogProbFn | Function that takes parameters and input batch and returns the log posterior (which can be unnormalised). | required |
| temperature | float | Temperature to rescale (divide) log_posterior. | 1.0 |
| n_samples | int | Number of samples to use for the Monte Carlo estimate. | 1 |
| stl | bool | Whether to use the stick-the-landing estimator from [Roeder et al](https://arxiv.org/abs/1703.09194). | True |

Returns:

| Type | Description |
| --- | --- |
| Tuple[float, Any] | The sampled approximate NELBO averaged over the batch. |
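
The NELBO can also be evaluated directly, e.g. to monitor a candidate (mean, sd_diag) pair on a held-out batch; mean, sd_diag, log_posterior and val_batch below are illustrative placeholders:

```
from posteriors.vi.diag import nelbo

# More Monte Carlo samples give a lower-variance estimate at extra cost.
nelbo_val, aux = nelbo(mean, sd_diag, val_batch, log_posterior, n_samples=5)
print(nelbo_val.item())
```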

Source code in posteriors/vi/diag.py
def nelbo(
    mean: TensorTree,
    sd_diag: TensorTree,
    batch: Any,
    log_posterior: LogProbFn,
    temperature: float = 1.0,
    n_samples: int = 1,
    stl: bool = True,
) -> Tuple[float, Any]:
    """Returns the negative evidence lower bound (NELBO) for a diagonal Normal
    variational distribution over the parameters of a model.

    Monte Carlo estimate with `n_samples` from q.
    $$
    \\text{NELBO} = - 𝔼_{q(θ)}[\\log p(y|x, θ) + \\log p(θ) - \\log q(θ) \\cdot T]
    $$
    for temperature $T$.

    `log_posterior` expects to take parameters and input batch and return a scalar
    as well as a TensorTree of any auxiliary information:

    ```
    log_posterior_eval, aux = log_posterior(params, batch)
    ```

    The log posterior and temperature are recommended to be [constructed in tandem](../../log_posteriors.md)
    to ensure robust scaling for a large amount of data and variable batch size.

    Args:
        mean: Mean of the variational distribution.
        sd_diag: Square-root diagonal of the covariance matrix of the
            variational distribution.
        batch: Input data to log_posterior.
        log_posterior: Function that takes parameters and input batch and
            returns the log posterior (which can be unnormalised).
        temperature: Temperature to rescale (divide) log_posterior.
        n_samples: Number of samples to use for Monte Carlo estimate.
        stl: Whether to use the stick-the-landing estimator
        from [Roeder et al](https://arxiv.org/abs/1703.09194).

    Returns:
        The sampled approximate NELBO averaged over the batch.
    """
    sampled_params = diag_normal_sample(mean, sd_diag, sample_shape=(n_samples,))
    if stl:
        mean = tree_map(lambda x: x.detach(), mean)
        sd_diag = tree_map(lambda x: x.detach(), sd_diag)

    # Don't use vmap for single sample, since vmap doesn't work with lots of models
    if n_samples == 1:
        single_param = tree_map(lambda x: x[0], sampled_params)
        log_p, aux = log_posterior(single_param, batch)
        log_q = diag_normal_log_prob(single_param, mean, sd_diag)
    else:
        log_p, aux = vmap(log_posterior, (0, None), (0, 0))(sampled_params, batch)
        log_q = vmap(diag_normal_log_prob, (0, None, None))(
            sampled_params, mean, sd_diag
        )
    return -(log_p - log_q * temperature).mean(), aux

posteriors.vi.diag.sample(state, sample_shape=torch.Size([]))

Single sample from diagonal Normal distribution over parameters.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| state | VIDiagState | State encoding mean and log standard deviations. | required |
| sample_shape | Size | Shape of the desired samples. | Size([]) |

Returns:

| Type | Description |
| --- | --- |
| TensorTree | Sample(s) from Normal distribution. |
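
For example, parameter draws from a fitted state might be used for Monte Carlo predictions (a sketch; state is assumed to be a trained VIDiagState):

```
import torch
from posteriors.vi.diag import sample

single_draw = sample(state)                                  # one parameter tree
many_draws = sample(state, sample_shape=torch.Size([100]))   # 100 stacked draws
```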

Source code in posteriors/vi/diag.py
def sample(state: VIDiagState, sample_shape: torch.Size = torch.Size([])) -> TensorTree:
    """Single sample from diagonal Normal distribution over parameters.

    Args:
        state: State encoding mean and log standard deviations.
        sample_shape: Shape of the desired samples.

    Returns:
        Sample(s) from Normal distribution.
    """
    sd_diag = tree_map(torch.exp, state.log_sd_diag)
    return diag_normal_sample(state.params, sd_diag, sample_shape=sample_shape)