SGLD 𝞡

posteriors.sgmcmc.sgld.build(log_posterior, lr, beta=0.0, temperature=1.0) 𝞡

Builds SGLD transform.

Algorithm from [Welling and Teh, 2011](https://www.stats.ox.ac.uk/~teh/research/compstats/WelTeh2011a.pdf): $$ θ_{t+1} = θ_t + ε \nabla \log p(θ_t, \text{batch}) + N(0, ε (2 - ε β) T \mathbb{I}) $$ for learning rate \(\epsilon\) and temperature \(T\).

Targets \(p_T(θ) \propto \exp( \log p(θ) / T)\) with temperature \(T\).

The log posterior and temperature are recommended to be [constructed in tandem](../../log_posteriors.md) to ensure robust scaling with large datasets and variable batch sizes.
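
As a hedged sketch of that recipe (the Gaussian model and `num_data` below are purely illustrative, not part of the API): make the batch estimate of the log posterior batch-size independent by averaging the log-likelihood and down-weighting the prior, then set the temperature to one over the dataset size.

import torch

num_data = 1000  # illustrative total number of training points

def log_posterior(params, batch):
    # Batch-size-independent estimate of log p(θ, data) / num_data:
    # mean log-likelihood over the batch plus the prior scaled by 1 / num_data.
    x, y = batch
    log_lik = torch.distributions.Normal(x @ params, 1.0).log_prob(y).mean()
    log_prior = torch.distributions.Normal(0.0, 1.0).log_prob(params).sum()
    return log_lik + log_prior / num_data, None  # aux is None here

temperature = 1 / num_data  # recovers the untempered posterior under this scaling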

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `log_posterior` | `LogProbFn` | Function that takes parameters and input batch and returns the log posterior value (which can be unnormalised) as well as auxiliary information, e.g. from the model call. | required |
| `lr` | `float` | Learning rate. | required |
| `beta` | `float` | Gradient noise coefficient (estimated variance). | `0.0` |
| `temperature` | `float` | Temperature of the sampling distribution. | `1.0` |

Returns:

| Type | Description |
| --- | --- |
| `Transform` | SGLD transform (`posteriors.types.Transform` instance). |

Source code in posteriors/sgmcmc/sgld.py
def build(
    log_posterior: LogProbFn,
    lr: float,
    beta: float = 0.0,
    temperature: float = 1.0,
) -> Transform:
    """Builds SGLD transform.

    Algorithm from [Welling and Teh, 2011](https://www.stats.ox.ac.uk/~teh/research/compstats/WelTeh2011a.pdf):
    $$
    θ_{t+1} = θ_t + ε \\nabla \\log p(θ_t, \\text{batch}) + N(0, ε  (2 - ε β) T \\mathbb{I})
    $$
    for learning rate $\\epsilon$ and temperature $T$.

    Targets $p_T(θ) \\propto \\exp( \\log p(θ) / T)$ with temperature $T$.

    The log posterior and temperature are recommended to be [constructed in tandem](../../log_posteriors.md)
    to ensure robust scaling with large datasets and variable batch sizes.

    Args:
        log_posterior: Function that takes parameters and input batch and
            returns the log posterior value (which can be unnormalised)
            as well as auxiliary information, e.g. from the model call.
        lr: Learning rate.
        beta: Gradient noise coefficient (estimated variance).
        temperature: Temperature of the sampling distribution.

    Returns:
        SGLD transform (posteriors.types.Transform instance).
    """
    update_fn = partial(
        update,
        log_posterior=log_posterior,
        lr=lr,
        beta=beta,
        temperature=temperature,
    )
    return Transform(init, update_fn)
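
A minimal usage sketch, assuming a `log_posterior`, `temperature`, initial `params` and iterable `dataloader` as in the example above:

import posteriors

transform = posteriors.sgmcmc.sgld.build(
    log_posterior, lr=1e-3, temperature=temperature
)
state = transform.init(params)
for batch in dataloader:
    state = transform.update(state, batch)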

posteriors.sgmcmc.sgld.SGLDState 𝞡

Bases: NamedTuple

State encoding params for SGLD.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `params` | `TensorTree` | Parameters. |
| `log_posterior` | `Tensor` | Log posterior evaluation. |
| `aux` | `Any` | Auxiliary information from the log_posterior call. |

Source code in posteriors/sgmcmc/sgld.py
class SGLDState(NamedTuple):
    """State encoding params for SGLD.

    Attributes:
        params: Parameters.
        log_posterior: Log posterior evaluation.
        aux: Auxiliary information from the log_posterior call.
    """

    params: TensorTree
    log_posterior: torch.Tensor = torch.tensor([])
    aux: Any = None

posteriors.sgmcmc.sgld.init(params) 𝞡

Initialise SGLD.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `params` | `TensorTree` | Parameters for which to initialise. | required |

Returns:

| Type | Description |
| --- | --- |
| `SGLDState` | Initial SGLDState. |

Source code in posteriors/sgmcmc/sgld.py
def init(params: TensorTree) -> SGLDState:
    """Initialise SGLD.

    Args:
        params: Parameters for which to initialise.

    Returns:
        Initial SGLDState.
    """

    return SGLDState(params)
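
For illustration (the parameter shape here is arbitrary):

import torch
import posteriors

params = torch.zeros(10)  # hypothetical starting parameters
state = posteriors.sgmcmc.sgld.init(params)
# state.log_posterior starts as an empty tensor and state.aux as None;
# both are populated by the first update call.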

posteriors.sgmcmc.sgld.update(state, batch, log_posterior, lr, beta=0.0, temperature=1.0, inplace=False) 𝞡

Updates parameters for SGLD.

Update rule from [Welling and Teh, 2011](https://www.stats.ox.ac.uk/~teh/research/compstats/WelTeh2011a.pdf): $$ θ_{t+1} = θ_t + ε \nabla \log p(θ_t, \text{batch}) + N(0, ε (2 - ε β) T \mathbb{I}) $$ for learning rate \(\epsilon\) and temperature \(T\).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `state` | `SGLDState` | SGLDState containing params. | required |
| `batch` | `Any` | Data batch to be sent to log_posterior. | required |
| `log_posterior` | `LogProbFn` | Function that takes parameters and input batch and returns the log posterior value (which can be unnormalised) as well as auxiliary information, e.g. from the model call. | required |
| `lr` | `float` | Learning rate. | required |
| `beta` | `float` | Gradient noise coefficient (estimated variance). | `0.0` |
| `temperature` | `float` | Temperature of the sampling distribution. | `1.0` |
| `inplace` | `bool` | Whether to modify state in place. | `False` |

Returns:

| Type | Description |
| --- | --- |
| `SGLDState` | Updated state (whose tensors are pointers to the input state tensors if inplace=True). |

Source code in posteriors/sgmcmc/sgld.py
def update(
    state: SGLDState,
    batch: Any,
    log_posterior: LogProbFn,
    lr: float,
    beta: float = 0.0,
    temperature: float = 1.0,
    inplace: bool = False,
) -> SGLDState:
    """Updates parameters for SGLD.

    Update rule from [Welling and Teh, 2011](https://www.stats.ox.ac.uk/~teh/research/compstats/WelTeh2011a.pdf):
    $$
    θ_{t+1} = θ_t + ε \\nabla \\log p(θ_t, \\text{batch}) + N(0, ε  (2 - ε β) T \\mathbb{I})
    $$
    for learning rate $\\epsilon$ and temperature $T$.

    Args:
        state: SGLDState containing params.
        batch: Data batch to be sent to log_posterior.
        log_posterior: Function that takes parameters and input batch and
            returns the log posterior value (which can be unnormalised)
            as well as auxiliary information, e.g. from the model call.
        lr: Learning rate.
        beta: Gradient noise coefficient (estimated variance).
        temperature: Temperature of the sampling distribution.
        inplace: Whether to modify state in place.

    Returns:
        Updated state (whose tensors are pointers to the input state tensors if inplace=True).
    """
    with torch.no_grad(), CatchAuxError():
        grads, (log_post, aux) = grad_and_value(log_posterior, has_aux=True)(
            state.params, batch
        )

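    # One SGLD step per parameter tensor: a gradient step on the log posterior plus
    # Gaussian noise with std sqrt(temperature * lr * (2 - temperature * lr * beta)).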
    def transform_params(p, g):
        return (
            p
            + lr * g
            + (temperature * lr * (2 - temperature * lr * beta)) ** 0.5
            * torch.randn_like(p)
        )

    params = flexi_tree_map(transform_params, state.params, grads, inplace=inplace)

    if inplace:
        tree_insert_(state.log_posterior, log_post.detach())
        return state._replace(aux=aux)
    return SGLDState(params, log_post.detach(), aux)
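
Since `update` returns a full `SGLDState`, posterior samples can be collected by storing copies of `state.params` along the trajectory. A hedged sketch, with the burn-in and thinning values purely illustrative and `params` assumed to be a single tensor:

samples = []
state = init(params)
for i, batch in enumerate(dataloader):
    state = update(state, batch, log_posterior, lr=1e-3, temperature=temperature)
    # Discard the first 100 steps as burn-in and keep every 10th state thereafter;
    # clone so stored samples survive later inplace=True updates.
    if i >= 100 and i % 10 == 0:
        samples.append(state.params.clone())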