SGLRW
posteriors.sgmcmc.sglrw.build(log_posterior, lr, temperature=1.0)
Builds SGLRW transform - Stochastic Gradient Lattice Random Walk.
Algorithm from Mensch et al., 2026, adapted from Duffield et al., 2025: $$ \theta_{t+1} = \theta_t + \delta x \, \Delta(\theta_t, t) $$ where \(\delta x = \sqrt{2 \cdot \text{lr} \cdot T}\) is the spatial stepsize and \(\Delta(\theta_t, t)\) is a random binary-valued vector defined in the paper.
Targets \(p_T(\theta) \propto \exp(\log p(\theta) / T)\) with temperature \(T\), as it discretizes the overdamped Langevin SDE: $$ d\theta = \nabla \log p_T(\theta) \, dt + \sqrt{2T} \, dW $$
Construct the log posterior and temperature in tandem so that scaling stays robust for large datasets and variable batch sizes.
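To make the update rule concrete, here is a minimal scalar sketch in plain Python. The exact definition of Δ is given in the paper; the rule below is only an assumed moment-matching choice (the probability of a +δx move is set so that the expected move equals lr times the gradient of the log posterior), and the names `spatial_stepsize`, `binary_lattice_step` and `run_chain` are hypothetical, not part of posteriors.

```python
import math
import random


def spatial_stepsize(lr, temperature=1.0):
    """Lattice spacing dx = sqrt(2 * lr * T)."""
    return math.sqrt(2.0 * lr * temperature)


def binary_lattice_step(theta, grad_log_p, lr, temperature=1.0, rng=random):
    """One +/- dx move for a scalar parameter (assumed moment-matching rule)."""
    dx = spatial_stepsize(lr, temperature)
    # Assumption: choose P(+1) so the expected move equals lr * grad_log_p.
    p_plus = 0.5 * (1.0 + lr * grad_log_p / dx)
    p_plus = min(1.0, max(0.0, p_plus))  # keep it a valid probability
    delta = 1 if rng.random() < p_plus else -1
    return theta + dx * delta


def run_chain(n_steps, lr, temperature=1.0, theta0=0.0, seed=0):
    """Sample a standard normal target, whose log density has gradient -theta."""
    rng = random.Random(seed)
    theta, samples = theta0, []
    for _ in range(n_steps):
        theta = binary_lattice_step(theta, -theta, lr, temperature, rng)
        samples.append(theta)
    return samples


samples = run_chain(n_steps=50_000, lr=0.05)
mean = sum(samples) / len(samples)
var = sum((s - mean) ** 2 for s in samples) / len(samples)
print(f"mean={mean:.3f} var={var:.3f}")
```

Every sample stays on the lattice of integer multiples of δx around the starting point, and for this target the empirical mean and variance should land close to 0 and 1.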
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `log_posterior` | `LogProbFn` | Function that takes parameters and input batch and returns the log posterior value (which can be unnormalised) as well as auxiliary information, e.g. from the model call. | required |
| `lr` | `float \| Schedule` | Learning rate, scalar or schedule (callable taking step index, returning scalar). | required |
| `temperature` | `float \| Schedule` | Temperature of the sampling distribution. Scalar or schedule (callable taking step index, returning scalar). | `1.0` |
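As an illustration of the "scalar or schedule" convention for `lr` and `temperature`, the sketch below resolves either form at a given step and shows how the spatial stepsize \(\sqrt{2 \cdot \text{lr} \cdot T}\) shrinks with a decaying learning rate. `polynomial_decay` and `resolve` are hypothetical helpers written for this example; how the library evaluates schedules internally is not shown here.

```python
import math


def polynomial_decay(step, lr0=1e-2, gamma=0.55):
    """Hypothetical schedule: lr0 * (step + 1) ** -gamma."""
    return lr0 * (step + 1) ** (-gamma)


def resolve(value, step):
    """Evaluate a schedule at `step`, or pass a scalar through unchanged."""
    return value(step) if callable(value) else value


for step in (0, 10, 100):
    lr = resolve(polynomial_decay, step)
    temperature = resolve(1.0, step)  # a plain scalar also works
    delta_x = math.sqrt(2.0 * lr * temperature)  # spatial stepsize
    print(step, lr, delta_x)
```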
Returns:
| Type | Description |
|---|---|
| `Transform` | SGLRW transform (posteriors.types.Transform instance). |
Source code in posteriors/sgmcmc/sglrw.py
posteriors.sgmcmc.sglrw.SGLRWState
Bases: TensorClass['frozen']
State encoding params for SG-LRW (binary).
Attributes:
| Name | Type | Description |
|---|---|---|
| `params` | `TensorTree` | Parameters. |
| `log_posterior` | `Tensor` | Last log posterior evaluation. |
| `step` | `Tensor` | Current step count. |
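The frozen state pattern can be pictured with a plain-Python analogue. This is only a sketch using a frozen dataclass, not the tensordict `TensorClass` the library actually subclasses; `LatticeState` and the values below are made up for illustration. A frozen state is never mutated: each update produces a new instance.

```python
from dataclasses import FrozenInstanceError, dataclass, replace


@dataclass(frozen=True)
class LatticeState:
    """Hypothetical stand-in mirroring the three attributes listed above."""

    params: tuple
    log_posterior: float = float("nan")
    step: int = 0


state = LatticeState(params=(0.0, 1.0))

# An update returns a fresh state rather than editing fields in place.
new_state = replace(
    state, params=(0.1, 0.9), log_posterior=-1.3, step=state.step + 1
)

try:
    state.step = 5  # rejected because the dataclass is frozen
except FrozenInstanceError:
    print("state is immutable")
```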
posteriors.sgmcmc.sglrw.init(params)
Initialise SG-LRW.
posteriors.sgmcmc.sglrw.ternary_probs(drift_val, diffusion_val, stepsize, delta_x)
Generate the probabilities for the ternary update from the discretization parameters.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `drift_val` | `Tensor` | Evaluation of the Drift function. | required |
| `diffusion_val` | `Tensor` | Evaluation of the Diffusion function. | required |
| `stepsize` | `Tensor` | Temporal stepsize value. | required |
| `delta_x` | `Tensor` | Spatial stepsize value. | required |
Returns:
| Type | Description |
|---|---|
| `Tensor` | Update probabilities as a tensor, with last axis being `[p_minus, p_zero, p_plus]`. |
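One way to picture how such probabilities can arise is to match the first two moments of a single lattice move to those of the SDE increment. The scalar sketch below is an assumption for illustration, not the library's implementation: `ternary_probs_sketch` is a hypothetical name, and it assumes `diffusion_val` is the coefficient σ multiplying dW, so that one step has mean `drift_val * stepsize` and variance `diffusion_val**2 * stepsize`.

```python
def ternary_probs_sketch(drift_val, diffusion_val, stepsize, delta_x):
    """Return [p_minus, p_zero, p_plus] for a move of -dx, 0 or +dx.

    Assumption: diffusion_val is the coefficient sigma multiplying dW.
    """
    mean = drift_val * stepsize
    second_moment = diffusion_val**2 * stepsize + mean**2
    m = mean / delta_x  # equals p_plus - p_minus
    s = second_moment / delta_x**2  # equals p_plus + p_minus
    probs = [0.5 * (s - m), 1.0 - s, 0.5 * (s + m)]
    # Clamp and renormalise when exact matching is impossible on this lattice.
    probs = [max(0.0, p) for p in probs]
    total = sum(probs)
    return [p / total for p in probs]


p_minus, p_zero, p_plus = ternary_probs_sketch(
    drift_val=1.0, diffusion_val=1.0, stepsize=0.1, delta_x=0.5
)
print(p_minus, p_zero, p_plus)
```

Under these assumptions, choosing `delta_x = sqrt(diffusion_val**2 * stepsize)` with zero drift drives `p_zero` to zero, which recovers a binary ±δx update.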