SGNHT
posteriors.sgmcmc.sgnht.build(log_posterior, lr, alpha=0.01, beta=0.0, sigma=1.0, temperature=1.0, momenta=None, xi=None)
Builds an SGNHT (stochastic gradient Nosé-Hoover thermostat) transform.
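Algorithm from Ding et al, 2014:
\[
\begin{aligned}
θ_{t+1} &= θ_t + \epsilon σ^{-2} m_t \\
m_{t+1} &= m_t + \epsilon \nabla \log p(θ_t, \text{batch}) - \epsilon σ^{-2} ξ_t m_t + N(0, \epsilon T (2 α - \epsilon β T) \mathbb{I}) \\
ξ_{t+1} &= ξ_t + \epsilon (σ^{-2} d^{-1} m_t^T m_t - T)
\end{aligned}
\]
for learning rate \(\epsilon\), temperature \(T\) and parameter dimension \(d\).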
Targets \(p_T(θ, m, ξ) \propto \exp( (\log p(θ) - \frac{1}{2σ^2} m^Tm - \frac{d}{2}(ξ - α)^2) / T)\).
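To make the roles of \(\epsilon\), \(α\), \(β\), \(σ\), \(T\) and \(ξ\) concrete, here is a minimal pure-Python sketch of one SGNHT step on flat lists of floats. It is an illustration under stated assumptions, not the library's implementation: the function name `sgnht_step`, the list-based parameters, the `grad_log_post` callable (returning \(\nabla \log p(θ)\)) and the `rng` argument are hypothetical. It assumes the step takes the form \(θ \leftarrow θ + \epsilon σ^{-2} m\), \(m \leftarrow m + \epsilon \nabla \log p(θ) - \epsilon σ^{-2} ξ m + N(0, \epsilon T (2α - \epsilon β T))\), \(ξ \leftarrow ξ + \epsilon (σ^{-2} d^{-1} m^T m - T)\), with every right-hand side evaluated at the current state. The thermostat \(ξ\) rises when the momenta run hotter than \(T\), which increases the friction on \(m\).

```python
import math
import random


def sgnht_step(theta, m, xi, grad_log_post, lr, alpha=0.01, beta=0.0,
               sigma=1.0, temperature=1.0, rng=random):
    """One SGNHT step on flat lists of floats (illustrative sketch only).

    Requires 2 * alpha >= lr * beta * temperature so the noise variance is non-negative.
    """
    d = len(theta)
    g = grad_log_post(theta)  # gradient of the log posterior at the current theta_t
    # Standard deviation of the injected noise: sqrt(eps * T * (2 * alpha - eps * beta * T)).
    noise_sd = math.sqrt(lr * temperature * (2 * alpha - lr * beta * temperature))
    # Parameter step uses the current momenta m_t.
    new_theta = [t + lr * sigma**-2 * mi for t, mi in zip(theta, m)]
    # Momentum step: gradient drift, thermostat friction xi_t and Gaussian noise.
    new_m = [
        mi + lr * gi - lr * sigma**-2 * xi * mi + noise_sd * rng.gauss(0.0, 1.0)
        for mi, gi in zip(m, g)
    ]
    # Thermostat step: compare sigma^-2 * m^T m / d (twice the kinetic energy
    # per dimension) with the temperature T.
    scaled_sq = sum(mi * mi for mi in m) / (sigma**2 * d)
    new_xi = xi + lr * (scaled_sq - temperature)
    return new_theta, new_m, new_xi


# Standard normal target, so grad log p(theta) = -theta.
theta, m, xi = sgnht_step([1.0], [1.0], 0.0, lambda th: [-t for t in th], lr=0.1, alpha=0.0)
```

With \(α = 0\) and \(β = 0\) the noise vanishes, so the step can be checked by hand: starting from \(θ = 1\), \(m = 1\), \(ξ = 0\) with \(\epsilon = 0.1\), it gives \(θ = 1.1\), \(m = 0.9\) and \(ξ = 0\).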
We recommend constructing the log posterior and the temperature in tandem, so that the scaling stays robust for large datasets and variable batch sizes.
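One way to construct them in tandem, assumed here rather than stated on this page, is to use a per-datum log posterior (mean batch log-likelihood plus the log prior divided by the dataset size \(N\)) together with temperature \(T = 1/N\). Since the target is \(\exp(\log p / T)\), dividing by \(T = 1/N\) recovers the full-data log posterior, while averaging over the batch keeps the scale fixed whatever the batch size. The Gaussian model and all function names below are hypothetical.

```python
import math


def log_lik(theta: float, y: float) -> float:
    # Hypothetical model: y ~ N(theta, 1).
    return -0.5 * (y - theta) ** 2 - 0.5 * math.log(2 * math.pi)


def log_prior(theta: float) -> float:
    # Hypothetical prior: theta ~ N(0, 1).
    return -0.5 * theta ** 2 - 0.5 * math.log(2 * math.pi)


def per_datum_log_posterior(theta: float, batch: list, num_data: int) -> float:
    """Mean batch log-likelihood plus log prior / N (unnormalised)."""
    mean_ll = sum(log_lik(theta, y) for y in batch) / len(batch)
    return mean_ll + log_prior(theta) / num_data


def full_log_posterior(theta: float, data: list) -> float:
    """Unnormalised full-data log posterior, for comparison."""
    return sum(log_lik(theta, y) for y in data) + log_prior(theta)


data = [0.3, -1.2, 0.8, 2.0]
N = len(data)
temperature = 1.0 / N
# With the whole dataset as the batch, log p / T matches the full log posterior.
print(per_datum_log_posterior(0.5, data, N) / temperature, full_log_posterior(0.5, data))
```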
Parameters:
Name | Type | Description | Default |
---|---|---|---|
log_posterior | LogProbFn | Function that takes parameters and an input batch and returns the log posterior value (which can be unnormalised) as well as auxiliary information, e.g. from the model call. | required |
lr | float | Learning rate. | required |
alpha | float | Friction coefficient. | 0.01 |
beta | float | Gradient noise coefficient (estimated variance). | 0.0 |
sigma | float | Standard deviation of the momenta target distribution. | 1.0 |
temperature | float | Temperature of the joint parameter + momenta distribution. | 1.0 |
momenta | TensorTree \| float \| None | Initial momenta. Can be tree-like params or a scalar. Defaults to random iid samples from N(0, 1). | None |
xi | float | Initial value for scalar thermostat ξ. Defaults to | None |
Returns:
Type | Description |
---|---|
Transform | SGNHT transform instance. |
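The signature suggests that the returned transform bundles an init function and an update function with the hyperparameters already fixed, so callers only pass params to init and a state plus batch to update. The sketch below shows that pattern with `functools.partial`; `ToyTransform`, `toy_build`, `toy_init` and `toy_update` are hypothetical stand-ins, and the placeholder update only records the bound hyperparameters instead of performing an SGNHT step.

```python
from functools import partial
from typing import Any, Callable, NamedTuple


class ToyTransform(NamedTuple):
    """Hypothetical stand-in for a Transform: an (init, update) pair."""
    init: Callable[..., Any]
    update: Callable[..., Any]


def toy_init(params, xi=0.01):
    # Builds a minimal state: params, zero momenta and the thermostat value.
    return {"params": list(params), "momenta": [0.0] * len(params), "xi": xi}


def toy_update(state, batch, lr, temperature=1.0):
    # Placeholder: records the bound hyperparameters instead of taking a step.
    return {**state, "lr": lr, "temperature": temperature, "batch": batch}


def toy_build(lr, temperature=1.0, xi=0.01):
    """Fix hyperparameters once so callers pass only params or (state, batch)."""
    return ToyTransform(
        init=partial(toy_init, xi=xi),
        update=partial(toy_update, lr=lr, temperature=temperature),
    )


transform = toy_build(lr=0.1, temperature=0.5)
state = transform.init([1.0, 2.0])
state = transform.update(state, batch=[0.3, -0.2])
```

Binding hyperparameters up front keeps the sampling loop free of configuration: every iteration calls the same two-argument update.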
Source code in posteriors/sgmcmc/sgnht.py
posteriors.sgmcmc.sgnht.SGNHTState
Bases: NamedTuple
State encoding params, momenta and the thermostat ξ for SGNHT.
Attributes:
Name | Type | Description |
---|---|---|
params | TensorTree | Parameters. |
momenta | TensorTree | Momenta for each parameter. |
xi | float | Scalar thermostat ξ. |
log_posterior | tensor | Log posterior evaluation. |
aux | Any | Auxiliary information from the log_posterior call. |
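Because the state is a NamedTuple, its fields cannot be reassigned and a modified state is produced with `_replace`. The sketch below mirrors the state's fields with plain Python types; `ToySGNHTState`, its list-of-floats params and the NaN/None defaults for log_posterior and aux are assumptions for illustration. The `xi` field is included because the update documentation describes the state as containing params, momenta and xi.

```python
import math
from typing import Any, List, NamedTuple


class ToySGNHTState(NamedTuple):
    """Hypothetical stand-in for SGNHTState using plain Python types."""
    params: List[float]
    momenta: List[float]
    xi: float
    log_posterior: float = math.nan  # assumed default: not yet evaluated
    aux: Any = None                  # assumed default: no auxiliary output


state = ToySGNHTState(params=[0.0, 1.0], momenta=[0.5, -0.5], xi=0.01)
# Fields cannot be reassigned; _replace returns a new tuple instead.
next_state = state._replace(xi=0.02, log_posterior=-1.3)
```

Note that `_replace` copies references, not contents: both tuples above share the same params list.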
Source code in posteriors/sgmcmc/sgnht.py
posteriors.sgmcmc.sgnht.init(params, momenta=None, xi=0.01)
Initialise momenta and the thermostat ξ for SGNHT.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
params | TensorTree | Parameters for which to initialise. | required |
momenta | TensorTree \| float \| None | Initial momenta. Can be tree-like params or a scalar. Defaults to random iid samples from N(0, 1). | None |
xi | float \| Tensor | Initial value for scalar thermostat ξ. | 0.01 |
|
Returns:
Type | Description |
---|---|
SGNHTState | Initial SGNHTState containing params, momenta and ξ. |
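The momenta argument accepts three forms. The pure-Python sketch below shows how they could be resolved for a flat list of parameters: None draws iid N(0, 1) values, a scalar is broadcast to every parameter, and a list of matching length is used as given. `toy_sgnht_init`, the dict-based state and the optional `rng` argument are hypothetical; the documented function works on tensor trees.

```python
import random


def toy_sgnht_init(params, momenta=None, xi=0.01, rng=None):
    """Hypothetical init for a flat list of float params."""
    if rng is None:
        rng = random.Random()
    if momenta is None:
        # Default: one iid N(0, 1) draw per parameter.
        momenta = [rng.gauss(0.0, 1.0) for _ in params]
    elif isinstance(momenta, (int, float)):
        # Scalar: broadcast the same value to every parameter.
        momenta = [float(momenta)] * len(params)
    else:
        # Tree-like (here: list-like) momenta must match params.
        momenta = list(momenta)
        if len(momenta) != len(params):
            raise ValueError("momenta must have the same length as params")
    return {"params": list(params), "momenta": momenta, "xi": float(xi)}


state = toy_sgnht_init([1.0, 2.0, 3.0], momenta=0.5)
```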
Source code in posteriors/sgmcmc/sgnht.py
posteriors.sgmcmc.sgnht.update(state, batch, log_posterior, lr, alpha=0.01, beta=0.0, sigma=1.0, temperature=1.0, inplace=False)
Updates parameters, momenta and xi for SGNHT.
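Update rule from Ding et al, 2014:
\[
\begin{aligned}
θ_{t+1} &= θ_t + \epsilon σ^{-2} m_t \\
m_{t+1} &= m_t + \epsilon \nabla \log p(θ_t, \text{batch}) - \epsilon σ^{-2} ξ_t m_t + N(0, \epsilon T (2 α - \epsilon β T) \mathbb{I}) \\
ξ_{t+1} &= ξ_t + \epsilon (σ^{-2} d^{-1} m_t^T m_t - T)
\end{aligned}
\]
for learning rate \(\epsilon\), temperature \(T\) and parameter dimension \(d\).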
Parameters:
Name | Type | Description | Default |
---|---|---|---|
state | SGNHTState | SGNHTState containing params, momenta and xi. | required |
batch | Any | Data batch to be sent to log_posterior. | required |
log_posterior | LogProbFn | Function that takes parameters and an input batch and returns the log posterior value (which can be unnormalised) as well as auxiliary information, e.g. from the model call. | required |
lr | float | Learning rate. | required |
alpha | float | Friction coefficient. | 0.01 |
beta | float | Gradient noise coefficient (estimated variance). | 0.0 |
sigma | float | Standard deviation of the momenta target distribution. | 1.0 |
temperature | float | Temperature of the joint parameter + momenta distribution. | 1.0 |
inplace | bool | Whether to modify state in place. | False |
|
Returns:
Type | Description |
---|---|
SGNHTState | Updated SGNHTState (whose tensors are pointers to the input state's tensors if inplace=True). |
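The inplace flag decides whether the returned state shares storage with the input. The sketch below illustrates the two behaviours with a hypothetical `scale_momenta` update on a hypothetical dict-based state: with `inplace=False` the caller's list is left untouched and a new one is returned, while with `inplace=True` the input list is mutated and the same object comes back. It shows the semantics only and is not the library's code.

```python
def scale_momenta(state, factor, inplace=False):
    """Hypothetical update that scales momenta, honouring an inplace flag."""
    if inplace:
        m = state["momenta"]
        for i in range(len(m)):
            m[i] *= factor  # mutate the caller's list directly
        return state        # same object as the input
    # Out-of-place: build a new state with a fresh momenta list.
    return {**state, "momenta": [mi * factor for mi in state["momenta"]]}


state = {"momenta": [1.0, 2.0]}
fresh = scale_momenta(state, 2.0)                 # state is unchanged
same = scale_momenta(state, 2.0, inplace=True)    # state is modified
```

In-place updates save memory, but any earlier sample that should be kept must be copied before the next update overwrites it.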
Source code in posteriors/sgmcmc/sgnht.py