VI Diag

posteriors.vi.diag.build(log_posterior, optimizer, temperature=1.0, n_samples=1, stl=True, init_log_sds=0.0)

Builds a transform for variational inference with a diagonal Normal distribution over parameters.

Find \(\mu\) and diagonal \(\Sigma\) that minimize \(\text{KL}(N(θ| \mu, \Sigma) || p_T(θ))\) where \(p_T(θ) \propto \exp( \log p(θ) / T)\) with temperature \(T\).
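
As a hedged illustration of this objective (not part of the library), consider a one-dimensional Gaussian target \(p = N(m, s^2)\). Tempering gives \(p_T = N(m, T s^2)\), so the KL-minimizing Normal has \(\mu = m\) and variance \(T s^2\). The helper `kl_normal` and all numbers below are assumptions made for this sketch.

```python
import math


def kl_normal(mu_q, sd_q, mu_p, sd_p):
    """Closed-form KL(N(mu_q, sd_q^2) || N(mu_p, sd_p^2)) in one dimension."""
    return (
        math.log(sd_p / sd_q)
        + (sd_q**2 + (mu_q - mu_p) ** 2) / (2 * sd_p**2)
        - 0.5
    )


# Hypothetical target N(m, s^2) and temperature T, chosen for illustration.
m, s, T = 2.0, 0.5, 4.0

# exp(log p(θ) / T) is proportional to N(m, T * s^2).
tempered_sd = math.sqrt(T) * s

# KL is zero when q matches the tempered target...
kl_at_optimum = kl_normal(m, tempered_sd, m, tempered_sd)
# ...and positive when q matches the untempered target instead.
kl_untempered = kl_normal(m, s, m, tempered_sd)
```

With \(T > 1\) the optimal variational distribution is wider than the untempered posterior, and with \(T < 1\) it is narrower.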

The log posterior and temperature are recommended to be constructed in tandem to ensure robust scaling for a large amount of data.
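
One possible reading of "in tandem", sketched here as an assumption rather than the library's prescribed recipe: if `log_posterior` returns the full-data log posterior divided by the dataset size \(N\), then choosing \(T = 1/N\) leaves the target \(p_T\) equal to the untempered posterior while the values seen by the optimizer stay on a per-datum scale. All names and numbers below are hypothetical.

```python
import math

N = 10_000  # hypothetical dataset size


def full_log_posterior(theta):
    # Hypothetical unnormalised log posterior whose magnitude grows with N.
    return -0.5 * N * (theta - 1.0) ** 2


def scaled_log_posterior(theta):
    # Per-datum scale: dividing by N keeps values O(1) as N grows.
    return full_log_posterior(theta) / N


temperature = 1.0 / N

# log p_T(θ) = scaled_log_posterior(θ) / T recovers the full log posterior.
thetas = [0.0, 0.5, 1.0, 2.0]
recovered = [scaled_log_posterior(t) / temperature for t in thetas]
expected = [full_log_posterior(t) for t in thetas]
```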

For more information on variational inference see Blei et al, 2017.

Parameters:

Name Type Description Default
log_posterior Callable[[TensorTree, Any], float]

Function that takes parameters and input batch and returns the log posterior (which can be unnormalised).

required
optimizer GradientTransformation

TorchOpt functional optimizer for updating the variational parameters. Make sure to use lower case like torchopt.adam()

required
temperature float

Temperature to rescale (divide) log_posterior.

1.0
n_samples int

Number of samples to use for Monte Carlo estimate.

1
stl bool

Whether to use the stick-the-landing estimator from Roeder et al (https://arxiv.org/abs/1703.09194).

True
init_log_sds TensorTree | float

Initial log of the square-root diagonal of the covariance matrix of the variational distribution. Can be a tree matching params or scalar.

0.0

Returns:

Type Description
Transform

Diagonal VI transform instance.

Source code in posteriors/vi/diag.py
def build(
    log_posterior: Callable[[TensorTree, Any], float],
    optimizer: torchopt.base.GradientTransformation,
    temperature: float = 1.0,
    n_samples: int = 1,
    stl: bool = True,
    init_log_sds: TensorTree | float = 0.0,
) -> Transform:
    """Builds a transform for variational inference with a diagonal Normal
    distribution over parameters.

    Find $\\mu$ and diagonal $\\Sigma$ that minimize $\\text{KL}(N(θ| \\mu, \\Sigma) || p_T(θ))$
    where $p_T(θ) \\propto \\exp( \\log p(θ) / T)$ with temperature $T$.

    The log posterior and temperature are recommended to be [constructed in tandem](../../log_posteriors.md)
    to ensure robust scaling for a large amount of data.

    For more information on variational inference see [Blei et al, 2017](https://arxiv.org/abs/1601.00670).

    Args:
        log_posterior: Function that takes parameters and input batch and
            returns the log posterior (which can be unnormalised).
        optimizer: TorchOpt functional optimizer for updating the variational
            parameters. Make sure to use lower case like torchopt.adam()
        temperature: Temperature to rescale (divide) log_posterior.
        n_samples: Number of samples to use for Monte Carlo estimate.
        stl: Whether to use the stick-the-landing estimator
            from [Roeder et al](https://arxiv.org/abs/1703.09194).
        init_log_sds: Initial log of the square-root diagonal of the covariance matrix
            of the variational distribution. Can be a tree matching params or scalar.

    Returns:
        Diagonal VI transform instance.
    """
    init_fn = partial(init, optimizer=optimizer, init_log_sds=init_log_sds)
    update_fn = partial(
        update,
        log_posterior=log_posterior,
        optimizer=optimizer,
        temperature=temperature,
        n_samples=n_samples,
        stl=stl,
    )
    return Transform(init_fn, update_fn)

posteriors.vi.diag.VIDiagState

Bases: NamedTuple

State encoding a diagonal Normal variational distribution over parameters.

Attributes:

Name Type Description
params TensorTree

Mean of the variational distribution.

log_sd_diag TensorTree

Log of the square-root diagonal of the covariance matrix of the variational distribution.

opt_state OptState

TorchOpt state storing optimizer data for updating the variational parameters.

nelbo tensor

Negative evidence lower bound (lower is better).

aux Any

Auxiliary information from the log_posterior call.
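
Storing the log of the standard deviation, rather than the standard deviation itself, means that any unconstrained gradient step still maps back to a positive standard deviation after exponentiating. A minimal sketch with made-up numbers (the gradient value and step size are assumptions, not library defaults):

```python
import math

log_sd = 0.0          # matches the default init_log_sds, i.e. sd = 1
learning_rate = 0.5   # hypothetical step size
gradient = 10.0       # hypothetical, deliberately large gradient

# A step of this size taken directly on sd crosses zero and becomes invalid.
sd_direct = math.exp(log_sd) - learning_rate * gradient

# A step of the same size taken on log_sd always maps back to a positive sd.
log_sd_new = log_sd - learning_rate * gradient
sd_from_log = math.exp(log_sd_new)
```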

Source code in posteriors/vi/diag.py
class VIDiagState(NamedTuple):
    """State encoding a diagonal Normal variational distribution over parameters.

    Attributes:
        params: Mean of the variational distribution.
        log_sd_diag: Log of the square-root diagonal of the covariance matrix of the
            variational distribution.
        opt_state: TorchOpt state storing optimizer data for updating the
            variational parameters.
        nelbo: Negative evidence lower bound (lower is better).
        aux: Auxiliary information from the log_posterior call.
    """

    params: TensorTree
    log_sd_diag: TensorTree
    opt_state: torchopt.typing.OptState
    nelbo: torch.tensor = torch.tensor([])
    aux: Any = None

posteriors.vi.diag.init(params, optimizer, init_log_sds=0.0)

Initialise diagonal Normal variational distribution over parameters.

optimizer.init will be called on flattened variational parameters, so hyperparameters such as the learning rate need to be pre-specified through TorchOpt's functional API:

import torchopt

optimizer = torchopt.adam(lr=1e-2)
vi_state = init(init_mean, optimizer)

It's assumed maximize=False for the optimizer, so that we minimize the NELBO.

Parameters:

Name Type Description Default
params TensorTree

Initial mean of the variational distribution.

required
optimizer GradientTransformation

TorchOpt functional optimizer for updating the variational parameters. Make sure to use lower case like torchopt.adam()

required
init_log_sds TensorTree | float

Initial log of the square-root diagonal of the covariance matrix of the variational distribution. Can be a tree matching params or scalar.

0.0

Returns:

Type Description
VIDiagState

Initial VIDiagState.
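
The scalar form of init_log_sds can be pictured as filling a tree of the same structure as params with that one value. The sketch below uses nested dicts of plain lists as a stand-in for a TensorTree; `tree_map_like` is a hypothetical helper written for this illustration, not the library's `tree_map`.

```python
def tree_map_like(fn, tree):
    """Apply fn to every leaf of a nested dict (a stand-in for a TensorTree)."""
    if isinstance(tree, dict):
        return {key: tree_map_like(fn, value) for key, value in tree.items()}
    return fn(tree)


# Hypothetical parameters: leaves are plain lists instead of tensors.
params = {"linear": {"weight": [0.1, -0.3, 0.7], "bias": [0.0]}}

# A scalar init_log_sds is expanded to one value per parameter entry.
init_log_sds = -2.0
log_sd_tree = tree_map_like(lambda leaf: [init_log_sds] * len(leaf), params)
```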

Source code in posteriors/vi/diag.py
def init(
    params: TensorTree,
    optimizer: torchopt.base.GradientTransformation,
    init_log_sds: TensorTree | float = 0.0,
) -> VIDiagState:
    """Initialise diagonal Normal variational distribution over parameters.

    optimizer.init will be called on flattened variational parameters, so hyperparameters
    such as the learning rate need to be pre-specified through TorchOpt's functional API:

    ```
    import torchopt

    optimizer = torchopt.adam(lr=1e-2)
    vi_state = init(init_mean, optimizer)
    ```

    It's assumed maximize=False for the optimizer, so that we minimize the NELBO.

    Args:
        params: Initial mean of the variational distribution.
        optimizer: TorchOpt functional optimizer for updating the variational
            parameters. Make sure to use lower case like torchopt.adam()
        init_log_sds: Initial log of the square-root diagonal of the covariance matrix
            of the variational distribution. Can be a tree matching params or scalar.

    Returns:
        Initial VIDiagState.
    """
    if is_scalar(init_log_sds):
        init_log_sds = tree_map(
            lambda x: torch.full_like(x, init_log_sds, requires_grad=x.requires_grad),
            params,
        )

    opt_state = optimizer.init([params, init_log_sds])
    return VIDiagState(params, init_log_sds, opt_state)

posteriors.vi.diag.update(state, batch, log_posterior, optimizer, temperature=1.0, n_samples=1, stl=True, inplace=False)

Updates the variational parameters to minimize the NELBO.

Parameters:

Name Type Description Default
state VIDiagState

Current state.

required
batch Any

Input data to log_posterior.

required
log_posterior LogProbFn

Function that takes parameters and input batch and returns the log posterior (which can be unnormalised).

required
optimizer GradientTransformation

TorchOpt functional optimizer for updating the variational parameters. Make sure to use lower case like torchopt.adam()

required
temperature float

Temperature to rescale (divide) log_posterior.

1.0
n_samples int

Number of samples to use for Monte Carlo estimate.

1
stl bool

Whether to use the stick-the-landing estimator from Roeder et al (https://arxiv.org/abs/1703.09194).

True
inplace bool

Whether to modify state in place.

False

Returns:

Type Description
VIDiagState

Updated VIDiagState.
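
To make the update concrete, here is a simplified sketch under stated assumptions: a one-dimensional Gaussian target \(N(m, s^2)\), hand-derived stick-the-landing gradients in place of automatic differentiation, and plain gradient descent in place of a TorchOpt optimizer. The names `stl_nelbo_grads` and `fit` are hypothetical and do not exist in the library.

```python
import math
import random


def stl_nelbo_grads(mu, log_sd, eps, m, s, T):
    """Stick-the-landing NELBO gradients for target N(m, s^2), q = N(mu, sd^2).

    Only the path through the sample theta = mu + sd * eps is differentiated;
    the explicit dependence of log q on (mu, sd) is treated as constant.
    """
    sd = math.exp(log_sd)
    theta = mu + sd * eps
    dlogp = -(theta - m) / s**2     # d log p / d theta
    dlogq = -(theta - mu) / sd**2   # d log q / d theta with (mu, sd) held fixed
    dnelbo_dtheta = -(dlogp - T * dlogq)
    # Chain rule: d theta / d mu = 1, d theta / d log_sd = sd * eps.
    return dnelbo_dtheta, dnelbo_dtheta * sd * eps


def fit(m, s, T, lr=0.01, steps=5000, seed=0):
    """Plain gradient descent on (mu, log_sd) with one sample per step."""
    rng = random.Random(seed)
    mu, log_sd = 0.0, 0.0
    for _ in range(steps):
        eps = rng.gauss(0.0, 1.0)
        g_mu, g_log_sd = stl_nelbo_grads(mu, log_sd, eps, m, s, T)
        mu -= lr * g_mu
        log_sd -= lr * g_log_sd
    return mu, math.exp(log_sd)


# Hypothetical target N(2, 0.5^2) with temperature 1.
mu_hat, sd_hat = fit(m=2.0, s=0.5, T=1.0)
```

In this Gaussian setting the stick-the-landing gradient is zero for every noise draw once \(\mu = m\) and \(\sigma^2 = T s^2\), which is why the iterates settle rather than jitter around the optimum.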

Source code in posteriors/vi/diag.py
def update(
    state: VIDiagState,
    batch: Any,
    log_posterior: LogProbFn,
    optimizer: torchopt.base.GradientTransformation,
    temperature: float = 1.0,
    n_samples: int = 1,
    stl: bool = True,
    inplace: bool = False,
) -> VIDiagState:
    """Updates the variational parameters to minimize the NELBO.

    Args:
        state: Current state.
        batch: Input data to log_posterior.
        log_posterior: Function that takes parameters and input batch and
            returns the log posterior (which can be unnormalised).
        optimizer: TorchOpt functional optimizer for updating the variational
            parameters. Make sure to use lower case like torchopt.adam()
        temperature: Temperature to rescale (divide) log_posterior.
        n_samples: Number of samples to use for Monte Carlo estimate.
        stl: Whether to use the stick-the-landing estimator
            from [Roeder et al](https://arxiv.org/abs/1703.09194).
        inplace: Whether to modify state in place.

    Returns:
        Updated VIDiagState.
    """

    def nelbo_log_sd(m, lsd):
        sd_diag = tree_map(torch.exp, lsd)
        return nelbo(m, sd_diag, batch, log_posterior, temperature, n_samples, stl)

    with torch.no_grad(), CatchAuxError():
        nelbo_grads, (nelbo_val, aux) = grad_and_value(
            nelbo_log_sd, argnums=(0, 1), has_aux=True
        )(state.params, state.log_sd_diag)

    updates, opt_state = optimizer.update(
        nelbo_grads,
        state.opt_state,
        params=[state.params, state.log_sd_diag],
        inplace=inplace,
    )
    mean, log_sd_diag = torchopt.apply_updates(
        (state.params, state.log_sd_diag), updates, inplace=inplace
    )

    if inplace:
        tree_insert_(state.nelbo, nelbo_val.detach())
        return state._replace(aux=aux)

    return VIDiagState(mean, log_sd_diag, opt_state, nelbo_val.detach(), aux)

posteriors.vi.diag.nelbo(mean, sd_diag, batch, log_posterior, temperature=1.0, n_samples=1, stl=True)

Returns the negative evidence lower bound (NELBO) for a diagonal Normal variational distribution over the parameters of a model.

Monte Carlo estimate with n_samples from q. $$ \text{NELBO} = - 𝔼_{q(θ)}[\log p(y|x, θ) + \log p(θ) - T \log q(θ)] $$ for temperature \(T\).

log_posterior expects to take parameters and input batch and return a scalar as well as a TensorTree of any auxiliary information:

log_posterior_eval, aux = log_posterior(params, batch)

The log posterior and temperature are recommended to be constructed in tandem to ensure robust scaling for a large amount of data and variable batch size.
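
The estimator can be sketched in plain Python for a single scalar parameter. This is an illustration under simplifying assumptions, not the library code: the hypothetical `log_posterior` here takes only the parameter and returns only a scalar (no batch, no auxiliary output), and `normal_log_prob` and `nelbo_estimate` are names invented for this sketch.

```python
import math
import random


def normal_log_prob(theta, mean, sd):
    """Log density of N(mean, sd^2) at theta."""
    return (
        -0.5 * ((theta - mean) / sd) ** 2
        - math.log(sd)
        - 0.5 * math.log(2 * math.pi)
    )


def nelbo_estimate(mean, sd, log_posterior, temperature=1.0, n_samples=1, seed=0):
    """Monte Carlo NELBO: average of -(log p - T * log q) over samples from q."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_samples):
        theta = mean + sd * rng.gauss(0.0, 1.0)  # reparameterised sample from q
        log_p = log_posterior(theta)
        log_q = normal_log_prob(theta, mean, sd)
        total += -(log_p - temperature * log_q)
    return total / n_samples


# Hypothetical normalised target N(0, 1). With q equal to the target and T = 1,
# log p - log q vanishes for every sample, so the estimate is zero.
value = nelbo_estimate(
    0.0, 1.0, lambda theta: normal_log_prob(theta, 0.0, 1.0), n_samples=8
)
```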

Parameters:

Name Type Description Default
mean dict

Mean of the variational distribution.

required
sd_diag dict

Square-root diagonal of the covariance matrix of the variational distribution.

required
batch Any

Input data to log_posterior.

required
log_posterior LogProbFn

Function that takes parameters and input batch and returns the log posterior (which can be unnormalised).

required
temperature float

Temperature to rescale (divide) log_posterior.

1.0
n_samples int

Number of samples to use for Monte Carlo estimate.

1
stl bool

Whether to use the stick-the-landing estimator from Roeder et al (https://arxiv.org/abs/1703.09194).

True

Returns:

Type Description
Tuple[float, Any]

The sampled approximate NELBO averaged over the batch, together with the auxiliary information returned by log_posterior.

Source code in posteriors/vi/diag.py
def nelbo(
    mean: dict,
    sd_diag: dict,
    batch: Any,
    log_posterior: LogProbFn,
    temperature: float = 1.0,
    n_samples: int = 1,
    stl: bool = True,
) -> Tuple[float, Any]:
    """Returns the negative evidence lower bound (NELBO) for a diagonal Normal
    variational distribution over the parameters of a model.

    Monte Carlo estimate with `n_samples` from q.
    $$
    \\text{NELBO} = - 𝔼_{q(θ)}[\\log p(y|x, θ) + \\log p(θ) - T \\log q(θ)]
    $$
    for temperature $T$.

    `log_posterior` expects to take parameters and input batch and return a scalar
    as well as a TensorTree of any auxiliary information:

    ```
    log_posterior_eval, aux = log_posterior(params, batch)
    ```

    The log posterior and temperature are recommended to be [constructed in tandem](../../log_posteriors.md)
    to ensure robust scaling for a large amount of data and variable batch size.

    Args:
        mean: Mean of the variational distribution.
        sd_diag: Square-root diagonal of the covariance matrix of the
            variational distribution.
        batch: Input data to log_posterior.
        log_posterior: Function that takes parameters and input batch and
            returns the log posterior (which can be unnormalised).
        temperature: Temperature to rescale (divide) log_posterior.
        n_samples: Number of samples to use for Monte Carlo estimate.
        stl: Whether to use the stick-the-landing estimator
            from [Roeder et al](https://arxiv.org/abs/1703.09194).

    Returns:
        The sampled approximate NELBO averaged over the batch, together with the
        auxiliary information returned by log_posterior.
    """
    sampled_params = diag_normal_sample(mean, sd_diag, sample_shape=(n_samples,))
    if stl:
        mean = tree_map(lambda x: x.detach(), mean)
        sd_diag = tree_map(lambda x: x.detach(), sd_diag)

    # Don't use vmap for single sample, since vmap doesn't work with lots of models
    if n_samples == 1:
        single_param = tree_map(lambda x: x[0], sampled_params)
        log_p, aux = log_posterior(single_param, batch)
        log_q = diag_normal_log_prob(single_param, mean, sd_diag)
    else:
        log_p, aux = vmap(log_posterior, (0, None), (0, 0))(sampled_params, batch)
        log_q = vmap(diag_normal_log_prob, (0, None, None))(
            sampled_params, mean, sd_diag
        )
    return -(log_p - log_q * temperature).mean(), aux

posteriors.vi.diag.sample(state, sample_shape=torch.Size([]))

Sample from the diagonal Normal distribution over parameters. The default sample_shape gives a single sample.

Parameters:

Name Type Description Default
state VIDiagState

State encoding mean and log standard deviations.

required
sample_shape Size

Shape of the desired samples.

Size([])

Returns:

Type Description
TensorTree

Sample(s) from Normal distribution.
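
Sampling amounts to exponentiating the stored log standard deviations and adding scaled standard Normal noise to the mean. A minimal standard-library sketch, assuming flat lists in place of a TensorTree; `sample_diag_normal` is a hypothetical name, not the library's `diag_normal_sample`.

```python
import math
import random
import statistics


def sample_diag_normal(mean, log_sd, n, seed=0):
    """Draw n samples per coordinate as mean + exp(log_sd) * standard normal noise."""
    rng = random.Random(seed)
    return [
        [mu + math.exp(lsd) * rng.gauss(0.0, 1.0) for mu, lsd in zip(mean, log_sd)]
        for _ in range(n)
    ]


# Hypothetical two-parameter state: means and log standard deviations.
mean = [1.0, -2.0]
log_sd = [math.log(0.5), math.log(0.1)]

draws = sample_diag_normal(mean, log_sd, n=20_000)

# Empirical moments of the first coordinate should be close to (1.0, 0.5).
first_coord = [d[0] for d in draws]
empirical_mean = statistics.fmean(first_coord)
empirical_sd = statistics.pstdev(first_coord)
```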

Source code in posteriors/vi/diag.py
def sample(state: VIDiagState, sample_shape: torch.Size = torch.Size([])) -> TensorTree:
    """Sample from the diagonal Normal distribution over parameters.

    The default sample_shape gives a single sample.

    Args:
        state: State encoding mean and log standard deviations.
        sample_shape: Shape of the desired samples.

    Returns:
        Sample(s) from Normal distribution.
    """
    sd_diag = tree_map(torch.exp, state.log_sd_diag)
    return diag_normal_sample(state.params, sd_diag, sample_shape=sample_shape)