
SGLD𝞡

posteriors.sgmcmc.sgld.build(log_posterior, lr, beta=0.0, temperature=1.0) 𝞡

Builds SGLD transform.

Algorithm from Welling and Teh, 2011: $$ θ_{t+1} = θ_t + ε \nabla \log p(θ_t, \text{batch}) + N(0, ε (2 - ε β) T \mathbb{I}) $$ for learning rate \(\epsilon\) and temperature \(T\).

Targets \(p_T(θ) \propto \exp( \log p(θ) / T)\) with temperature \(T\).

It is recommended to construct the log posterior and temperature in tandem to ensure robust scaling with large datasets and variable batch sizes.
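For intuition, here is a hypothetical sketch of that tandem construction in plain Python (the helper name `build_log_posterior` and its signature are illustrative, not part of the library): scale the minibatch mean log-likelihood up to the full dataset size, and set the temperature to `1 / num_data` so the untempered Bayesian posterior is targeted regardless of batch size.

```python
def build_log_posterior(log_prior, log_likelihood, num_data):
    # Hypothetical helper (not part of posteriors): scales the mean
    # per-sample log-likelihood by the dataset size so that a minibatch
    # gives an unbiased estimate of the full-data log posterior.
    def log_posterior(params, batch):
        batch_log_lik = sum(log_likelihood(params, x) for x in batch)
        log_lik = num_data * batch_log_lik / len(batch)
        return log_lik + log_prior(params), None  # (value, aux) pair
    return log_posterior

# Pairing this scaling with temperature = 1 / num_data targets the
# standard (untempered) Bayesian posterior.
num_data = 1000
temperature = 1 / num_data
```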

Parameters:

Name Type Description Default
log_posterior LogProbFn

Function that takes parameters and input batch and returns the log posterior value (which can be unnormalised) as well as auxiliary information, e.g. from the model call.

required
lr float | Schedule

Learning rate, scalar or schedule (callable taking step index, returning scalar).

required
beta float

Gradient noise coefficient (estimated variance).

0.0
temperature float | Schedule

Temperature of the sampling distribution. Scalar or schedule (callable taking step index, returning scalar).

1.0

Returns:

Type Description
Transform

SGLD transform (posteriors.types.Transform instance).

Source code in posteriors/sgmcmc/sgld.py
def build(
    log_posterior: LogProbFn,
    lr: float | Schedule,
    beta: float = 0.0,
    temperature: float | Schedule = 1.0,
) -> Transform:
    """Builds SGLD transform.

    Algorithm from [Welling and Teh, 2011](https://www.stats.ox.ac.uk/~teh/research/compstats/WelTeh2011a.pdf):
    $$
    θ_{t+1} = θ_t + ε \\nabla \\log p(θ_t, \\text{batch}) + N(0, ε  (2 - ε β) T \\mathbb{I})
    $$
    for learning rate $\\epsilon$ and temperature $T$.

    Targets $p_T(θ) \\propto \\exp( \\log p(θ) / T)$ with temperature $T$.

    The log posterior and temperature are recommended to be [constructed in tandem](../../log_posteriors.md)
    to ensure robust scaling with large datasets and variable batch sizes.

    Args:
        log_posterior: Function that takes parameters and input batch and
            returns the log posterior value (which can be unnormalised)
            as well as auxiliary information, e.g. from the model call.
        lr: Learning rate,
            scalar or schedule (callable taking step index, returning scalar).
        beta: Gradient noise coefficient (estimated variance).
        temperature: Temperature of the sampling distribution.
            Scalar or schedule (callable taking step index, returning scalar).

    Returns:
        SGLD transform (posteriors.types.Transform instance).
    """
    update_fn = partial(
        update,
        log_posterior=log_posterior,
        lr=lr,
        beta=beta,
        temperature=temperature,
    )
    return Transform(init, update_fn)
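As a minimal, self-contained sketch of the update rule above (pure Python on a scalar parameter, not using the library's torch TensorTrees), one SGLD step adds the scaled gradient and Gaussian noise with variance \(ε(2 - εβ)T\). Run on a standard normal target, the chain samples approximately from it:

```python
import math
import random

def sgld_step(theta, grad_log_post, lr, beta=0.0, temperature=1.0):
    # One SGLD step for a scalar parameter, mirroring the update rule:
    # theta' = theta + lr * grad + N(0, lr * (2 - lr * beta) * T)
    noise_std = math.sqrt(temperature * lr * (2 - lr * beta))
    return theta + lr * grad_log_post(theta) + noise_std * random.gauss(0.0, 1.0)

# Sample from a standard normal target: log p(theta) = -theta**2 / 2,
# so grad log p(theta) = -theta.
random.seed(0)
theta = 0.0
samples = []
for _ in range(20000):
    theta = sgld_step(theta, lambda t: -t, lr=0.05)
    samples.append(theta)
```

The empirical mean and variance of `samples` approach 0 and 1, up to the discretization bias of the fixed step size.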

posteriors.sgmcmc.sgld.SGLDState 𝞡

Bases: TensorClass['frozen']

State encoding params for SGLD.

Attributes:

Name Type Description
params TensorTree

Parameters.

log_posterior Tensor

Log posterior evaluation.

step Tensor

Current step count.

Source code in posteriors/sgmcmc/sgld.py
class SGLDState(TensorClass["frozen"]):
    """State encoding params for SGLD.

    Attributes:
        params: Parameters.
        log_posterior: Log posterior evaluation.
        step: Current step count.
    """

    params: TensorTree
    log_posterior: Tensor = torch.tensor(torch.nan)
    step: Tensor = torch.tensor(0)

posteriors.sgmcmc.sgld.init(params) 𝞡

Initialise SGLD.

Parameters:

Name Type Description Default
params TensorTree

Parameters for which to initialise.

required

Returns:

Type Description
SGLDState

Initial SGLDState.

Source code in posteriors/sgmcmc/sgld.py
def init(params: TensorTree) -> SGLDState:
    """Initialise SGLD.

    Args:
        params: Parameters for which to initialise.

    Returns:
        Initial SGLDState.
    """

    return SGLDState(params)

posteriors.sgmcmc.sgld.update(state, batch, log_posterior, lr, beta=0.0, temperature=1.0, inplace=False) 𝞡

Updates parameters for SGLD.

Update rule from Welling and Teh, 2011, see build for details.

Parameters:

Name Type Description Default
state SGLDState

SGLDState containing params.

required
batch Any

Data batch to be sent to log_posterior.

required
log_posterior LogProbFn

Function that takes parameters and input batch and returns the log posterior value (which can be unnormalised) as well as auxiliary information, e.g. from the model call.

required
lr float | Schedule

Learning rate, scalar or schedule (callable taking step index, returning scalar).

required
beta float

Gradient noise coefficient (estimated variance).

0.0
temperature float | Schedule

Temperature of the sampling distribution. Scalar or schedule (callable taking step index, returning scalar).

1.0
inplace bool

Whether to modify state in place.

False

Returns:

Type Description
SGLDState

Updated state (sharing the input state's tensors if inplace=True).

TensorTree

Auxiliary information.

Source code in posteriors/sgmcmc/sgld.py
def update(
    state: SGLDState,
    batch: Any,
    log_posterior: LogProbFn,
    lr: float | Schedule,
    beta: float = 0.0,
    temperature: float | Schedule = 1.0,
    inplace: bool = False,
) -> tuple[SGLDState, TensorTree]:
    """Updates parameters for SGLD.

    Update rule from [Welling and Teh, 2011](https://www.stats.ox.ac.uk/~teh/research/compstats/WelTeh2011a.pdf),
    see [build](sgld.md#posteriors.sgmcmc.sgld.build) for details.

    Args:
        state: SGLDState containing params.
        batch: Data batch to be sent to log_posterior.
        log_posterior: Function that takes parameters and input batch and
            returns the log posterior value (which can be unnormalised)
            as well as auxiliary information, e.g. from the model call.
        lr: Learning rate,
            scalar or schedule (callable taking step index, returning scalar).
        beta: Gradient noise coefficient (estimated variance).
        temperature: Temperature of the sampling distribution.
            Scalar or schedule (callable taking step index, returning scalar).
        inplace: Whether to modify state in place.

    Returns:
        Updated state (sharing the input state's tensors if inplace=True)
        and auxiliary information.
    """
    with torch.no_grad(), CatchAuxError():
        grads, (log_post, aux) = grad_and_value(log_posterior, has_aux=True)(
            state.params, batch
        )

    lr = lr(state.step) if callable(lr) else lr
    temperature = temperature(state.step) if callable(temperature) else temperature

    def transform_params(p, g):
        return (
            p
            + lr * g
            + (temperature * lr * (2 - temperature * lr * beta)) ** 0.5
            * torch.randn_like(p)
        )

    params = flexi_tree_map(transform_params, state.params, grads, inplace=inplace)

    if inplace:
        tree_insert_(state.log_posterior, log_post.detach())
        tree_insert_(state.step, state.step + 1)
        return state, aux
    return SGLDState(params, log_post.detach(), state.step + 1), aux
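Since a Schedule is just a callable from step index to scalar, lr and temperature can decay over time. A common choice for SGLD is the polynomial schedule from Welling and Teh, 2011, \(ε_t = a(b+t)^{-γ}\) with \(γ \in (0.5, 1]\); a hedged sketch (the helper name `polynomial_lr` and its default constants are illustrative, not part of the library):

```python
def polynomial_lr(a=1e-2, b=10.0, gamma=0.55):
    # Decaying step-size schedule from Welling and Teh, 2011:
    # eps_t = a * (b + t) ** (-gamma), with gamma in (0.5, 1] so the
    # Robbins-Monro conditions (sum eps_t = inf, sum eps_t**2 < inf) hold.
    def schedule(step):
        return a * (b + step) ** -gamma
    return schedule

lr = polynomial_lr()
# update evaluates a callable lr as lr(state.step) at every step,
# as shown in the source above.
```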