SGLRW

posteriors.sgmcmc.sglrw.build(log_posterior, lr, temperature=1.0)

Builds SGLRW transform - Stochastic Gradient Lattice Random Walk.

Algorithm from Mensch et al, 2026, adapted from Duffield et al, 2025: $$ θ_{t+1} = θ_t + δx \, Δ(θ_t, t) $$ where \(δx = \sqrt{2 \, lr \, T}\) is a spatial stepsize and \(Δ(θ_t, t)\) is a random binary-valued vector defined in the paper.

Targets \(p_T(θ) \propto \exp(\log p(θ) / T)\) with temperature \(T\), as it discretizes the overdamped Langevin SDE: $$ dθ = ∇ \log p_T(θ) \, dt + \sqrt{2T} \, dW $$
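As a quick sanity check on the stepsize choice (my own reasoning, not quoted from the paper): matching the first two moments of an Euler–Maruyama step of this SDE over a time step \(h = lr\) recovers the spatial stepsize \(δx\):

```latex
\theta_{t+1} - \theta_t = h \,\nabla \log p_T(\theta_t) + \sqrt{2T}\,\Delta W,
\qquad \Delta W \sim \mathcal{N}(0, h), \quad h = lr
```
```latex
\implies \mathbb{E}[\Delta\theta] = h\,\nabla \log p_T(\theta_t),
\qquad \operatorname{Var}[\Delta\theta] = 2Th = \delta x^2
```

So a lattice walk with spatial step \(δx = \sqrt{2 \, lr \, T}\) reproduces the diffusion of the SDE.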

It is recommended to construct the log posterior and temperature in tandem to ensure robust scaling with large datasets and variable batch sizes.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `log_posterior` | `LogProbFn` | Function that takes parameters and an input batch and returns the log posterior value (which can be unnormalised), as well as auxiliary information, e.g. from the model call. | required |
| `lr` | `float \| Schedule` | Learning rate: scalar or schedule (callable taking step index, returning scalar). | required |
| `temperature` | `float \| Schedule` | Temperature of the sampling distribution: scalar or schedule (callable taking step index, returning scalar). | `1.0` |
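For illustration, a schedule here is just a callable from step index to scalar. A minimal hypothetical example (the name `cosine_lr` is my own, not part of the library):

```python
import math

def cosine_lr(step: int, lr0: float = 1e-3, total: int = 1000) -> float:
    """Cosine decay from lr0 at step 0 down to 0 at step `total`."""
    return 0.5 * lr0 * (1 + math.cos(math.pi * min(step, total) / total))
```

Such a callable can be passed directly wherever a `Schedule` is accepted, e.g. as the `lr` argument.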

Returns:

| Type | Description |
| --- | --- |
| `Transform` | SGLRW transform (`posteriors.types.Transform` instance). |

Source code in posteriors/sgmcmc/sglrw.py
def build(
    log_posterior: LogProbFn,
    lr: float | Schedule,
    temperature: float | Schedule = 1.0,
) -> Transform:
    """Builds SGLRW transform - Stochastic Gradient Lattice Random Walk.

    Algorithm from [Mensch et al, 2026](https://arxiv.org/abs/2602.15925)
    adapted from [Duffield et al, 2025](https://arxiv.org/abs/2508.20883):
    $$
    θ_{t+1} = θ_t + δx Δ(θₜ, t)
    $$
    where $δx = \\sqrt{2 \\, lr \\, T}$ is a spatial stepsize and $Δ(θ_t, t)$ is a random
    binary-valued vector defined in the paper.

    Targets $p_T(θ) \\propto \\exp( \\log p(θ) / T)$ with temperature $T$,
    as it discretizes the overdamped Langevin SDE:
    $$
    dθ = ∇ \\log p_T(θ) \\, dt + \\sqrt{2T} \\, dW
    $$

    It is recommended to [construct the log posterior and temperature in tandem](../../log_posteriors.md)
    to ensure robust scaling with large datasets and variable batch sizes.

    Args:
        log_posterior: Function that takes parameters and input batch and
            returns the log posterior value (which can be unnormalised)
            as well as auxiliary information, e.g. from the model call.
        lr: Learning rate,
            scalar or schedule (callable taking step index, returning scalar).
        temperature: Temperature of the sampling distribution.
            Scalar or schedule (callable taking step index, returning scalar).

    Returns:
        SGLRW transform (posteriors.types.Transform instance).
    """
    update_fn = partial(
        update,
        log_posterior=log_posterior,
        lr=lr,
        temperature=temperature,
    )
    return Transform(init, update_fn)
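The built transform is driven via its `init`/`update` pair. Since `update` lives elsewhere in the module, here is instead a self-contained toy sketch (pure Python, 1D, my own illustration, not the library's implementation) of the lattice walk the docstring describes: steps of size \(δx = \sqrt{2 \, lr \, T}\), taken with the probabilities computed as in `ternary_probs`:

```python
import math
import random

def toy_sglrw_1d(grad_log_p, lr=0.01, temperature=1.0, n_steps=20000, seed=0):
    """Toy 1D lattice random walk targeting p(x) ∝ exp(log p(x))."""
    rng = random.Random(seed)
    delta_x = math.sqrt(2 * lr * temperature)  # spatial stepsize δx
    diffusion_sq = 2 * temperature             # diffusion² of the Langevin SDE
    theta, samples = 0.0, []
    for _ in range(n_steps):
        # Discretisation parameters, mirroring ternary_probs
        scaled_mean = lr * grad_log_p(theta) / delta_x
        scaled_var = min(max(lr * diffusion_sq / delta_x**2, 0.0), 1.0)
        scaled_mean = min(max(scaled_mean, -scaled_var), scaled_var)
        p_plus = 0.5 * (scaled_var + scaled_mean)
        p_minus = 0.5 * (scaled_var - scaled_mean)
        # Sample the lattice step from {-1, 0, +1}
        u = rng.random()
        step = 1 if u < p_plus else (-1 if u < p_plus + p_minus else 0)
        theta += delta_x * step
        samples.append(theta)
    return samples

# Standard normal target: log p(x) = -x²/2, so ∇ log p(x) = -x
samples = toy_sglrw_1d(lambda x: -x)
tail = samples[5000:]  # discard burn-in
mean = sum(tail) / len(tail)
var = sum((x - mean) ** 2 for x in tail) / len(tail)
```

Note that with `lr` and `temperature` as above, `scaled_var` is exactly 1, so `p_zero` vanishes and the walk is binary, consistent with the binary-valued \(Δ\) in the update rule.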

posteriors.sgmcmc.sglrw.SGLRWState

Bases: TensorClass['frozen']

State encoding params for SG-LRW (binary).

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `params` | `TensorTree` | Parameters. |
| `log_posterior` | `Tensor` | Last log posterior evaluation. |
| `step` | `Tensor` | Current step count. |

Source code in posteriors/sgmcmc/sglrw.py
class SGLRWState(TensorClass["frozen"]):
    """State encoding params for SG-LRW (binary).

    Attributes:
        params: Parameters.
        log_posterior: Last log posterior evaluation.
        step: Current step count.
    """

    params: TensorTree
    log_posterior: Tensor = torch.tensor(torch.nan)
    step: Tensor = torch.tensor(0)

posteriors.sgmcmc.sglrw.init(params)

Initialise SG-LRW.

Source code in posteriors/sgmcmc/sglrw.py
def init(params: TensorTree) -> SGLRWState:
    """Initialise SG-LRW."""
    return SGLRWState(params)

posteriors.sgmcmc.sglrw.ternary_probs(drift_val, diffusion_val, stepsize, delta_x)

Generate the probabilities for the ternary update from the discretization parameters.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `drift_val` | `Tensor` | Evaluation of the drift function. | required |
| `diffusion_val` | `Tensor` | Evaluation of the diffusion function. | required |
| `stepsize` | `Tensor` | Temporal stepsize value. | required |
| `delta_x` | `Tensor` | Spatial stepsize value. | required |

Returns:

| Type | Description |
| --- | --- |
| `Tensor` | Update probabilities as a tensor, with the last axis being `[p_minus, p_zero, p_plus]`. |

Source code in posteriors/sgmcmc/sglrw.py
def ternary_probs(
    drift_val: Tensor,
    diffusion_val: Tensor,
    stepsize: Tensor,
    delta_x: Tensor,
) -> Tensor:
    """
    Generate the probabilities for the ternary update
    from the discretization parameters.

    Args:
        drift_val: Evaluation of the Drift function.
        diffusion_val: Evaluation of the Diffusion function.
        stepsize: Temporal stepsize value.
        delta_x: Spatial stepsize value.

    Returns:
        Update probabilities as a tensor, with last axis being [p_minus, p_zero, p_plus].
    """
    desired_mean = stepsize * drift_val
    desired_var = stepsize * diffusion_val**2
    scaled_mean = desired_mean / delta_x
    scaled_var = desired_var / delta_x**2

    # Ensure p_minus + p_plus <= 1
    scaled_var = torch.clamp(scaled_var, 0.0, 1.0)

    # Ensure positive probs
    scaled_mean = torch.clamp(scaled_mean, -scaled_var, scaled_var)

    # Clip probs for numerical stability
    p_plus = torch.clamp(0.5 * (scaled_var + scaled_mean), 0.0, 1.0)
    p_minus = torch.clamp(0.5 * (scaled_var - scaled_mean), 0.0, 1.0)
    p_zero = torch.clamp(1 - p_plus - p_minus, 0.0, 1.0)

    return torch.stack([p_minus, p_zero, p_plus], dim=-1)
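To make the scaling and clamping concrete, here is the same arithmetic traced by hand on scalars (pure Python, with illustrative numbers of my own choosing, so the clamps are no-ops):

```python
stepsize, delta_x = 0.01, 0.2              # temporal and spatial stepsizes
drift_val, diffusion_val = 1.0, 2.0**0.5   # scalar drift and diffusion

desired_mean = stepsize * drift_val        # 0.01
desired_var = stepsize * diffusion_val**2  # 0.02
scaled_mean = desired_mean / delta_x       # 0.05
scaled_var = desired_var / delta_x**2      # 0.5

# Clamps change nothing here: 0 <= 0.5 <= 1 and -0.5 <= 0.05 <= 0.5
p_plus = 0.5 * (scaled_var + scaled_mean)  # 0.275
p_minus = 0.5 * (scaled_var - scaled_mean) # 0.225
p_zero = 1 - p_plus - p_minus              # 0.5
```

By construction `p_minus + p_zero + p_plus = 1`, and the resulting ternary step has mean `stepsize * drift_val` and variance `stepsize * diffusion_val**2` (up to the clamping), matching the SDE increment.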