BAOA
posteriors.sgmcmc.baoa.build(log_posterior, lr, alpha=0.01, sigma=1.0, temperature=1.0, momenta=None)
Builds BAOA transform.
Algorithm from Leimkuhler and Matthews, 2015 - p271.
BAOA is conjugate to BAOAB (in Leimkuhler and Matthews' terminology) but requires only a single gradient evaluation per iteration. The two are equivalent when analyzing functions of the parameter trajectory. Unlike BAOAB, BAOA is not reversible, but since we don't apply Metropolis-Hastings or momenta reversal, the algorithm remains functionally identical to BAOAB.
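The update equations themselves do not appear on this page. Reading the name as the sub-step sequence B (momentum kick), A (half position drift), O (Ornstein-Uhlenbeck momentum refresh), A (half position drift), a plausible form is sketched here. It is a reconstruction that assumes the stated target distribution and the definitions of \(γ\) and \(ζ^2\), not a quotation of the library docstring; \(ξ_t\) is an assumed symbol for standard normal noise.

```latex
\begin{aligned}
m_{t+1/2} &= m_t + ε \nabla \log p(θ_t), && \text{(B)} \\
θ_{t+1/2} &= θ_t + \tfrac{ε}{2} σ^{-2} m_{t+1/2}, && \text{(A)} \\
m_{t+1}   &= e^{-γε} m_{t+1/2} + ζ σ \, ξ_t, \quad ξ_t \sim N(0, I), && \text{(O)} \\
θ_{t+1}   &= θ_{t+1/2} + \tfrac{ε}{2} σ^{-2} m_{t+1} && \text{(A)}
\end{aligned}
```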
The update uses learning rate \(ε\), temperature \(T\), transformed friction \(γ = α σ^{-2}\) and transformed noise variance \(ζ^2 = T(1 - e^{-2γε})\).
Targets \(p_T(θ, m) \propto \exp( (\log p(θ) - \frac{1}{2σ^2} m^Tm) / T)\) with temperature \(T\).
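To make the target distribution concrete, here is a minimal pure-Python sketch of a B-A-O-A step for a single scalar parameter. It is an illustration under stated assumptions, not the code in posteriors/sgmcmc/baoa.py: `baoa_step` is a hypothetical helper, the toy target is a standard normal, and the sub-step order is inferred from the algorithm's name.

```python
import math
import random


def baoa_step(theta, m, grad_log_p, lr, alpha=0.01, sigma=1.0,
              temperature=1.0, rng=random):
    """One B-A-O-A step for a scalar parameter (illustrative sketch only)."""
    gamma = alpha / sigma**2  # transformed friction
    # Square root of the transformed noise variance zeta^2.
    zeta = math.sqrt(temperature * (1.0 - math.exp(-2.0 * gamma * lr)))

    m = m + lr * grad_log_p(theta)            # B: full momentum kick
    theta = theta + 0.5 * lr * m / sigma**2   # A: half position drift
    # O: exact Ornstein-Uhlenbeck refresh of the momentum.
    m = math.exp(-gamma * lr) * m + zeta * sigma * rng.gauss(0.0, 1.0)
    theta = theta + 0.5 * lr * m / sigma**2   # A: half position drift
    return theta, m


# Toy target (assumption): standard normal, log p(theta) = -theta^2 / 2,
# so the gradient of the log density is -theta.
rng = random.Random(0)
theta, m = 0.0, 0.0
samples = []
for _ in range(50_000):
    theta, m = baoa_step(theta, m, lambda t: -t, lr=0.1, alpha=1.0, rng=rng)
    samples.append(theta)

mean = sum(samples) / len(samples)
var = sum((s - mean) ** 2 for s in samples) / len(samples)
print(f"mean={mean:.3f}, var={var:.3f}")
```

With temperature 1 and this toy target, the empirical mean and variance of the parameter samples should land near 0 and 1 respectively.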
We recommend constructing the log posterior and the temperature in tandem, so that scaling remains robust for large datasets and variable batch sizes.
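One common way to pair the two, sketched here as an assumption rather than a requirement of this module, is to return a per-datum log posterior (mean batch log-likelihood plus the log prior divided by the dataset size) and to set the temperature to one over the dataset size. The toy data, `log_likelihood`, and `log_prior` below are hypothetical.

```python
import math

data = [0.5, -1.2, 0.3, 2.0, -0.7, 1.1]  # toy dataset (hypothetical)
num_data = len(data)


def log_likelihood(theta, x):
    # Unit-variance Gaussian likelihood centred at theta.
    return -0.5 * (x - theta) ** 2 - 0.5 * math.log(2.0 * math.pi)


def log_prior(theta):
    # Standard normal prior on theta.
    return -0.5 * theta**2 - 0.5 * math.log(2.0 * math.pi)


def log_posterior(theta, batch):
    """Per-datum log posterior: its magnitude does not grow with len(batch)."""
    mean_log_lik = sum(log_likelihood(theta, x) for x in batch) / len(batch)
    value = mean_log_lik + log_prior(theta) / num_data
    return value, {}  # (log posterior value, auxiliary information)


# Dividing by this temperature recovers the full, unscaled posterior.
temperature = 1.0 / num_data

theta = 0.4
full = sum(log_likelihood(theta, x) for x in data) + log_prior(theta)
rescaled = log_posterior(theta, data)[0] / temperature
print(full, rescaled)
```

On the full dataset the rescaled value matches the unscaled log posterior, while minibatch values stay on the same scale whatever the batch size.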
Parameters:
Name | Type | Description | Default |
---|---|---|---|
log_posterior | LogProbFn | Function that takes parameters and input batch and returns the log posterior value (which can be unnormalised) as well as auxiliary information, e.g. from the model call. | required |
lr | float \| Schedule | Learning rate. Scalar or schedule (callable taking step index, returning scalar). | required |
alpha | float | Friction coefficient. | 0.01 |
sigma | float | Standard deviation of momenta target distribution. | 1.0 |
temperature | float \| Schedule | Temperature of the joint parameter + momenta distribution. Scalar or schedule (callable taking step index, returning scalar). | 1.0 |
momenta | TensorTree \| float \| None | Initial momenta. Can be tree like params or scalar. Defaults to random iid samples from N(0, 1). | None |
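Since lr and temperature accept either a scalar or a schedule, a short sketch may help. The decay form, its constants, and the `resolve` helper are hypothetical illustrations of "callable taking step index, returning scalar", not functions provided by the library.

```python
def polynomial_decay(step, lr0=1e-2, decay=0.55, offset=10.0):
    """Hypothetical schedule: maps a step index to a scalar learning rate."""
    return lr0 * (1.0 + step / offset) ** (-decay)


def resolve(value, step):
    """Return value(step) for a schedule, or the scalar itself otherwise."""
    return value(step) if callable(value) else value


for step in (0, 10, 100):
    print(step, resolve(polynomial_decay, step), resolve(0.01, step))
```

Either form could then be supplied as lr or temperature, assuming the schedule is evaluated at the current step count.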
Returns:
Type | Description |
---|---|
Transform | BAOA transform instance. |
Source code in posteriors/sgmcmc/baoa.py
posteriors.sgmcmc.baoa.BAOAState
Bases: TensorClass['frozen']
State encoding params and momenta for BAOA.
Attributes:
Name | Type | Description |
---|---|---|
params | TensorTree | Parameters. |
momenta | TensorTree | Momenta for each parameter. |
log_posterior | Tensor | Log posterior evaluation. |
step | Tensor | Current step count. |
Source code in posteriors/sgmcmc/baoa.py
posteriors.sgmcmc.baoa.init(params, momenta=None)
Initialise momenta for BAOA.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
params | TensorTree | Parameters for which to initialise. | required |
momenta | TensorTree \| float \| None | Initial momenta. Can be tree like params or scalar. Defaults to random iid samples from N(0, 1). | None |
Returns:
Type | Description |
---|---|
BAOAState | Initial BAOAState containing momenta. |
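The three accepted forms of momenta (None, a scalar, or a tree shaped like params) can be pictured with a small stand-in. This sketch uses plain dicts of lists in place of tensor trees, and `init_momenta` is a hypothetical helper, not the library's init.

```python
import random


def init_momenta(params, momenta=None, rng=random):
    """Hypothetical helper mirroring the three documented cases for momenta."""
    if momenta is None:
        # Default: iid N(0, 1) draws, one per parameter entry.
        return {name: [rng.gauss(0.0, 1.0) for _ in values]
                for name, values in params.items()}
    if isinstance(momenta, (int, float)):
        # Scalar: broadcast the same value to every parameter entry.
        return {name: [float(momenta)] * len(values)
                for name, values in params.items()}
    # Otherwise assume a tree with the same structure as params.
    return momenta


params = {"weight": [0.1, -0.3, 0.7], "bias": [0.0]}
random_momenta = init_momenta(params, rng=random.Random(0))
zero_momenta = init_momenta(params, 0.0)
given_momenta = init_momenta(params, {"weight": [1.0, 1.0, 1.0], "bias": [2.0]})
print(random_momenta, zero_momenta, given_momenta)
```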
Source code in posteriors/sgmcmc/baoa.py
posteriors.sgmcmc.baoa.update(state, batch, log_posterior, lr, alpha=0.01, sigma=1.0, temperature=1.0, inplace=False)
Updates parameters and momenta for BAOA.
Algorithm from Leimkuhler and Matthews, 2015 - p271.
See build for more details.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
state | BAOAState | BAOAState containing params and momenta. | required |
batch | Any | Data batch to be sent to log_posterior. | required |
log_posterior | LogProbFn | Function that takes parameters and input batch and returns the log posterior value (which can be unnormalised) as well as auxiliary information, e.g. from the model call. | required |
lr | float \| Schedule | Learning rate. Scalar or schedule (callable taking step index, returning scalar). | required |
alpha | float | Friction coefficient. | 0.01 |
sigma | float | Standard deviation of momenta target distribution. | 1.0 |
temperature | float \| Schedule | Temperature of the joint parameter + momenta distribution. Scalar or schedule (callable taking step index, returning scalar). | 1.0 |
inplace | bool | Whether to modify state in place. | False |
Returns:
Type | Description |
---|---|
tuple[BAOAState, TensorTree] | Updated state (whose tensors are pointers to the input state tensors if inplace=True) and auxiliary information. |
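The inplace flag and the (state, auxiliary information) return pair can be illustrated with a deliberately simplified stand-in. The `toy_update` function below is hypothetical: it applies only a momentum kick to a dict-of-lists state, not the full BAOA step, and exists solely to show the aliasing behaviour.

```python
def toy_update(state, grads, lr, inplace=False):
    """Hypothetical stand-in: momentum kick only, returns (state, aux)."""
    # Copy the lists unless asked to write into the caller's state.
    target = state if inplace else {k: list(v) for k, v in state.items()}
    for i, g in enumerate(grads):
        target["momenta"][i] += lr * g
    aux = {"num_updated": len(grads)}  # placeholder auxiliary information
    return target, aux


state = {"params": [0.0, 1.0], "momenta": [0.0, 0.0]}

new_state, aux = toy_update(state, grads=[1.0, -2.0], lr=0.5)
same_state, _ = toy_update(state, grads=[1.0, -2.0], lr=0.5, inplace=True)
print(new_state is state, same_state is state)
```

With inplace=False the input state is left untouched and a fresh state is returned; with inplace=True the returned state is the same object as the input.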
Source code in posteriors/sgmcmc/baoa.py