VI Dense𝞡

posteriors.vi.dense.build(log_posterior, optimizer, temperature=1.0, n_samples=1, stl=True, init_L=1.0) 𝞡

Builds a transform for variational inference with a Normal distribution over parameters.

Find \(\mu\) and \(\Sigma\) that minimize \(\text{KL}(N(θ| \mu, \Sigma) || p_T(θ))\) where \(p_T(θ) \propto \exp( \log p(θ) / T)\) with temperature \(T\).

The log posterior and temperature are recommended to be [constructed in tandem](../../log_posteriors.md) to ensure robust scaling for large amounts of data.

For more information on variational inference see [Blei et al, 2017](https://arxiv.org/abs/1601.00670).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `log_posterior` | `Callable[[TensorTree, Any], float]` | Function that takes parameters and input batch and returns the log posterior (which can be unnormalised). | required |
| `optimizer` | `GradientTransformation` | TorchOpt functional optimizer for updating the variational parameters. Make sure to use lower case like `torchopt.adam()`. | required |
| `temperature` | `float` | Temperature to rescale (divide) log_posterior. | `1.0` |
| `n_samples` | `int` | Number of samples to use for the Monte Carlo estimate. | `1` |
| `stl` | `bool` | Whether to use the stick-the-landing estimator from [Roeder et al](https://arxiv.org/abs/1703.09194). | `True` |
| `init_L` | `Tensor \| float` | Initial lower triangular matrix \(L\) satisfying \(LL^T = \Sigma\). | `1.0` |

Returns:

| Type | Description |
| --- | --- |
| `Transform` | Dense VI transform instance. |

Source code in posteriors/vi/dense.py
def build(
    log_posterior: Callable[[TensorTree, Any], float],
    optimizer: torchopt.base.GradientTransformation,
    temperature: float = 1.0,
    n_samples: int = 1,
    stl: bool = True,
    init_L: torch.Tensor | float = 1.0,
) -> Transform:
    """Builds a transform for variational inference with a Normal
    distribution over parameters.

    Find $\\mu$ and $\\Sigma$ that minimize $\\text{KL}(N(θ| \\mu, \\Sigma) || p_T(θ))$
    where $p_T(θ) \\propto \\exp( \\log p(θ) / T)$ with temperature $T$.

    The log posterior and temperature are recommended to be [constructed in tandem](../../log_posteriors.md)
    to ensure robust scaling for a large amount of data.

    For more information on variational inference see [Blei et al, 2017](https://arxiv.org/abs/1601.00670).

    Args:
        log_posterior: Function that takes parameters and input batch and
            returns the log posterior (which can be unnormalised).
        optimizer: TorchOpt functional optimizer for updating the variational
            parameters. Make sure to use lower case like torchopt.adam()
        temperature: Temperature to rescale (divide) log_posterior.
        n_samples: Number of samples to use for Monte Carlo estimate.
        stl: Whether to use the stick-the-landing estimator
            from [Roeder et al](https://arxiv.org/abs/1703.09194).
        init_L: Initial lower triangular matrix $L$ satisfying $LL^T$ = $\\Sigma$.

    Returns:
        Dense VI transform instance.
    """
    init_fn = partial(init, optimizer=optimizer, init_L=init_L)
    update_fn = partial(
        update,
        log_posterior=log_posterior,
        optimizer=optimizer,
        temperature=temperature,
        n_samples=n_samples,
        stl=stl,
    )
    return Transform(init_fn, update_fn)
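
A usage sketch follows (the toy model, `log_posterior` and `dataloader` below are illustrative assumptions; the `transform.init`/`transform.update` pattern is the standard posteriors `Transform` interface):

```python
import torch
import torchopt
import posteriors

# Hypothetical toy setup: linear regression with a standard Normal prior.
params = {"weights": torch.zeros(10), "bias": torch.zeros(1)}

def log_posterior(params, batch):
    x, y = batch
    pred = x @ params["weights"] + params["bias"]
    log_lik = torch.distributions.Normal(pred, 1.0).log_prob(y).sum()
    log_prior = sum(
        torch.distributions.Normal(0.0, 1.0).log_prob(p).sum()
        for p in params.values()
    )
    return log_lik + log_prior, pred  # second element is auxiliary info

transform = posteriors.vi.dense.build(log_posterior, torchopt.adam(lr=1e-2))
state = transform.init(params)

for batch in dataloader:  # dataloader: hypothetical iterable of (x, y) batches
    state = transform.update(state, batch)

print(state.nelbo)  # latest NELBO estimate (lower is better)
```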

posteriors.vi.dense.VIDenseState 𝞡

Bases: NamedTuple

State encoding a Normal variational distribution over parameters with dense covariance.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `params` | `TensorTree` | Mean of the variational distribution. |
| `L_factor` | `Tensor` | Flat representation of the nonzero values of the lower triangular matrix \(L\) satisfying \(LL^T = \Sigma\), where \(\Sigma\) is the covariance matrix of the variational distribution. |
| `opt_state` | `OptState` | TorchOpt state storing optimizer data for updating the variational parameters. |
| `nelbo` | `Tensor` | Negative evidence lower bound (lower is better). |
| `aux` | `Any` | Auxiliary information from the log_posterior call. |

Source code in posteriors/vi/dense.py
class VIDenseState(NamedTuple):
    """State encoding a diagonal Normal variational distribution over parameters.

    Attributes:
        params: Mean of the variational distribution.
        L_factor: Flat representation of the nonzero values of the lower
            triangular matrix $L$ satisfying $LL^T$ = $\\Sigma$, where $\\Sigma$
            is the covariance matrix of the variational distribution.
        opt_state: TorchOpt state storing optimizer data for updating the
            variational parameters.
        nelbo: Negative evidence lower bound (lower is better).
        aux: Auxiliary information from the log_posterior call.
    """

    params: TensorTree
    L_factor: torch.Tensor
    opt_state: torchopt.typing.OptState
    nelbo: torch.Tensor = torch.tensor([])
    aux: Any = None
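
For orientation, a short sketch of inspecting a state (here `state` is assumed to come from `transform.init` or `transform.update` as in the example above):

```python
mean_tree = state.params   # TensorTree: mean of the variational Normal
flat_L = state.L_factor    # torch.Tensor: flat lower-triangular factor of Σ
print(flat_L.numel())      # d * (d + 1) / 2 values for d parameters
print(state.nelbo)         # most recent NELBO estimate
```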

posteriors.vi.dense.init(params, optimizer, init_L=1.0) 𝞡

Initialise Normal variational distribution over parameters with dense covariance.

optimizer.init will be called on flattened variational parameters, so hyperparameters such as learning rate need to be pre-specified through TorchOpt's functional API:

import torchopt

optimizer = torchopt.adam(lr=1e-2)
vi_state = init(init_mean, optimizer)

It's assumed maximize=False for the optimizer, so that we minimize the NELBO.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `params` | `TensorTree` | Initial mean of the variational distribution. | required |
| `optimizer` | `GradientTransformation` | TorchOpt functional optimizer for updating the variational parameters. Make sure to use lower case like `torchopt.adam()`. | required |
| `init_L` | `Tensor \| float` | Initial lower triangular matrix \(L\) satisfying \(LL^T = \Sigma\), where \(\Sigma\) is the covariance matrix of the variational distribution. | `1.0` |

Returns:

| Type | Description |
| --- | --- |
| `VIDenseState` | Initial VIDenseState. |

Source code in posteriors/vi/dense.py
def init(
    params: TensorTree,
    optimizer: torchopt.base.GradientTransformation,
    init_L: torch.Tensor | float = 1.0,
) -> VIDenseState:
    """Initialise diagonal Normal variational distribution over parameters.

    optimizer.init will be called on flattened variational parameters so hyperparameters
    such as learning rate need to pre-specified through TorchOpt's functional API:

    ```
    import torchopt

    optimizer = torchopt.adam(lr=1e-2)
    vi_state = init(init_mean, optimizer)
    ```

    It's assumed maximize=False for the optimizer, so that we minimize the NELBO.

    Args:
        params: Initial mean of the variational distribution.
        optimizer: TorchOpt functional optimizer for updating the variational
            parameters. Make sure to use lower case like torchopt.adam()
        init_L: Initial lower triangular matrix $L$ satisfying $LL^T$ = $\\Sigma$,
            where $\\Sigma$ is the covariance matrix of the variational distribution.

    Returns:
        Initial VIDenseState.
    """

    num_params = tree_size(params)
    if is_scalar(init_L):
        init_L = init_L * torch.eye(num_params, requires_grad=True)

    init_L = L_to_flat(init_L)
    opt_state = optimizer.init([params, init_L])
    return VIDenseState(params, init_L, opt_state)
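
A small sketch of initialising with a non-default starting covariance; a float `init_L` is expanded to `init_L * I` by the scalar branch above, so passing `0.1` starts from \(\Sigma = 0.01 I\) (the mean tree and learning rate are arbitrary illustrative choices):

```python
import torch
import torchopt
import posteriors

init_mean = {"w": torch.zeros(5), "b": torch.zeros(1)}

optimizer = torchopt.adam(lr=1e-2)
vi_state = posteriors.vi.dense.init(init_mean, optimizer, init_L=0.1)
```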

posteriors.vi.dense.update(state, batch, log_posterior, optimizer, temperature=1.0, n_samples=1, stl=True, inplace=False) 𝞡

Updates the variational parameters to minimize the NELBO.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `state` | `VIDenseState` | Current state. | required |
| `batch` | `Any` | Input data to log_posterior. | required |
| `log_posterior` | `LogProbFn` | Function that takes parameters and input batch and returns the log posterior (which can be unnormalised). | required |
| `optimizer` | `GradientTransformation` | TorchOpt functional optimizer for updating the variational parameters. Make sure to use lower case like `torchopt.adam()`. | required |
| `temperature` | `float` | Temperature to rescale (divide) log_posterior. | `1.0` |
| `n_samples` | `int` | Number of samples to use for the Monte Carlo estimate. | `1` |
| `stl` | `bool` | Whether to use the stick-the-landing estimator from [Roeder et al](https://arxiv.org/abs/1703.09194). | `True` |
| `inplace` | `bool` | Whether to modify state in place. | `False` |

Returns:

| Type | Description |
| --- | --- |
| `VIDenseState` | Updated VIDenseState. |

Source code in posteriors/vi/dense.py
def update(
    state: VIDenseState,
    batch: Any,
    log_posterior: LogProbFn,
    optimizer: torchopt.base.GradientTransformation,
    temperature: float = 1.0,
    n_samples: int = 1,
    stl: bool = True,
    inplace: bool = False,
) -> VIDenseState:
    """Updates the variational parameters to minimize the NELBO.

    Args:
        state: Current state.
        batch: Input data to log_posterior.
        log_posterior: Function that takes parameters and input batch and
            returns the log posterior (which can be unnormalised).
        optimizer: TorchOpt functional optimizer for updating the variational
            parameters. Make sure to use lower case like torchopt.adam()
        temperature: Temperature to rescale (divide) log_posterior.
        n_samples: Number of samples to use for Monte Carlo estimate.
        stl: Whether to use the stick-the-landing estimator
            from [Roeder et al](https://arxiv.org/abs/1703.09194).
        inplace: Whether to modify state in place.

    Returns:
        Updated VIDenseState.
    """

    def nelbo_L_factor(m, L_flat):
        return nelbo(m, L_flat, batch, log_posterior, temperature, n_samples, stl)

    with torch.no_grad(), CatchAuxError():
        nelbo_grads, (nelbo_val, aux) = grad_and_value(
            nelbo_L_factor, argnums=(0, 1), has_aux=True
        )(state.params, state.L_factor)

    updates, opt_state = optimizer.update(
        nelbo_grads,
        state.opt_state,
        params=[state.params, state.L_factor],
        inplace=inplace,
    )
    mean, L_factor = torchopt.apply_updates(
        (state.params, state.L_factor), updates, inplace=inplace
    )

    if inplace:
        tree_insert_(state.nelbo, nelbo_val.detach())
        return state._replace(aux=aux)

    return VIDenseState(mean, L_factor, opt_state, nelbo_val.detach(), aux)
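
A sketch of driving `update` directly rather than through `build` (here `log_posterior` is as in the earlier example, and `init_mean` and `batches` are assumed stand-ins for your initial parameters and data):

```python
optimizer = torchopt.adam(lr=1e-2)
state = posteriors.vi.dense.init(init_mean, optimizer)

for batch in batches:
    state = posteriors.vi.dense.update(
        state,
        batch,
        log_posterior=log_posterior,
        optimizer=optimizer,
        n_samples=3,  # more samples reduce Monte Carlo variance at extra cost
    )
```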

posteriors.vi.dense.nelbo(mean, L_factor, batch, log_posterior, temperature=1.0, n_samples=1, stl=True) 𝞡

Returns the negative evidence lower bound (NELBO) for a Normal variational distribution over the parameters of a model.

Monte Carlo estimate with n_samples from q. $$ \text{NELBO} = - 𝔼_{q(θ)}[\log p(y|x, θ) + \log p(θ) - T \log q(θ)] $$ for temperature \(T\).

log_posterior expects to take parameters and input batch and return a scalar as well as a TensorTree of any auxiliary information:

log_posterior_eval, aux = log_posterior(params, batch)

The log posterior and temperature are recommended to be [constructed in tandem](../../log_posteriors.md) to ensure robust scaling for large amounts of data and variable batch size.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `mean` | `dict` | Mean of the variational distribution. | required |
| `L_factor` | `Tensor` | Flat representation of the nonzero values of the lower triangular matrix \(L\) satisfying \(LL^T = \Sigma\), where \(\Sigma\) is the covariance matrix of the variational distribution. | required |
| `batch` | `Any` | Input data to log_posterior. | required |
| `log_posterior` | `LogProbFn` | Function that takes parameters and input batch and returns the log posterior (which can be unnormalised). | required |
| `temperature` | `float` | Temperature to rescale (divide) log_posterior. | `1.0` |
| `n_samples` | `int` | Number of samples to use for the Monte Carlo estimate. | `1` |
| `stl` | `bool` | Whether to use the stick-the-landing estimator from [Roeder et al](https://arxiv.org/abs/1703.09194). | `True` |

Returns:

| Type | Description |
| --- | --- |
| `Tuple[float, Any]` | The sampled approximate NELBO averaged over the batch. |

Source code in posteriors/vi/dense.py
def nelbo(
    mean: dict,
    L_factor: torch.Tensor,
    batch: Any,
    log_posterior: LogProbFn,
    temperature: float = 1.0,
    n_samples: int = 1,
    stl: bool = True,
) -> Tuple[float, Any]:
    """Returns the negative evidence lower bound (NELBO) for a Normal
    variational distribution over the parameters of a model.

    Monte Carlo estimate with `n_samples` from q.
    $$
    \\text{NELBO} = - 𝔼_{q(θ)}[\\log p(y|x, θ) + \\log p(θ) - T \\log q(θ)]
    $$
    for temperature $T$.

    `log_posterior` expects to take parameters and input batch and return a scalar
    as well as a TensorTree of any auxiliary information:

    ```
    log_posterior_eval, aux = log_posterior(params, batch)
    ```

    The log posterior and temperature are recommended to be [constructed in tandem](../../log_posteriors.md)
    to ensure robust scaling for a large amount of data and variable batch size.

    Args:
        mean: Mean of the variational distribution.
        L_factor: Flat representation of the nonzero values of the lower
            triangular matrix $L$ satisfying $LL^T$ = $\\Sigma$, where $\\Sigma$
            is the covariance matrix of the variational distribution.
        batch: Input data to log_posterior.
        log_posterior: Function that takes parameters and input batch and
            returns the log posterior (which can be unnormalised).
        temperature: Temperature to rescale (divide) log_posterior.
        n_samples: Number of samples to use for Monte Carlo estimate.
        stl: Whether to use the stick-the-landing estimator
            from [Roeder et al](https://arxiv.org/abs/1703.09194).

    Returns:
        The sampled approximate NELBO averaged over the batch.
    """

    mean_flat, unravel_func = tree_ravel(mean)
    L = L_from_flat(L_factor)
    cov = L @ L.T
    dist = torch.distributions.MultivariateNormal(
        loc=mean_flat,
        covariance_matrix=cov,
        validate_args=False,
    )

    sampled_params = dist.rsample((n_samples,))
    sampled_params_tree = torch.vmap(lambda s: unravel_func(s))(sampled_params)

    if stl:
        # Detach the variational parameters so gradients of log q do not flow
        # through them (stick-the-landing estimator)
        mean_flat = mean_flat.detach()
        L = L_from_flat(L_factor.detach())
        cov = L @ L.T
        # Redefine the distribution with detached parameters for the log q term
        dist = torch.distributions.MultivariateNormal(
            loc=mean_flat,
            covariance_matrix=cov,
            validate_args=False,
        )

    # Don't use vmap for single sample, since vmap doesn't work with lots of models
    if n_samples == 1:
        single_param = tree_map(lambda x: x[0], sampled_params_tree)
        single_param_flat, _ = tree_ravel(single_param)
        log_p, aux = log_posterior(single_param, batch)
        log_q = dist.log_prob(single_param_flat)

    else:
        log_p, aux = vmap(log_posterior, (0, None), (0, 0))(sampled_params_tree, batch)
        log_q = dist.log_prob(sampled_params)

    return -(log_p - log_q * temperature).mean(), aux
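
A sketch of evaluating the NELBO for a given state, e.g. to monitor convergence on a held-out batch (`val_batch` is an assumed held-out batch and `log_posterior` is as in the earlier examples):

```python
with torch.no_grad():
    nelbo_val, aux = posteriors.vi.dense.nelbo(
        state.params,
        state.L_factor,
        val_batch,
        log_posterior,
        n_samples=10,  # more samples give a lower-variance estimate
    )
print(nelbo_val.item())
```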

posteriors.vi.dense.sample(state, sample_shape=torch.Size([])) 𝞡

Sample(s) from the Normal distribution over parameters.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `state` | `VIDenseState` | State encoding mean and covariance matrix. | required |
| `sample_shape` | `Size` | Shape of the desired samples. | `torch.Size([])` |

Returns:

| Type | Description |
| --- | --- |
| `TensorTree` | Sample(s) from Normal distribution. |

Source code in posteriors/vi/dense.py
def sample(
    state: VIDenseState, sample_shape: torch.Size = torch.Size([])
) -> TensorTree:
    """Single sample from Normal distribution over parameters.

    Args:
        state: State encoding mean and covariance matrix.
        sample_shape: Shape of the desired samples.

    Returns:
        Sample(s) from Normal distribution.
    """

    mean_flat, unravel_func = tree_ravel(state.params)
    L = L_from_flat(state.L_factor)
    cov = L @ L.T

    samples = torch.distributions.MultivariateNormal(
        loc=mean_flat,
        covariance_matrix=cov,
        validate_args=False,
    ).rsample(sample_shape)

    samples = torch.vmap(unravel_func)(samples)
    return samples
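
A sketch of drawing samples from a fitted state and using them for a Monte Carlo predictive estimate (`model_fn` and `x_test` are hypothetical, standing in for a function mapping a parameter tree and inputs to predictions, and some test inputs):

```python
samples = posteriors.vi.dense.sample(state, sample_shape=torch.Size([100]))
# Every leaf of `samples` gains a leading dimension of size 100.

preds = torch.vmap(lambda p: model_fn(p, x_test))(samples)  # model_fn, x_test assumed
mean_pred = preds.mean(dim=0)  # average prediction over posterior samples
```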