SGHMC
posteriors.sgmcmc.sghmc.build(log_posterior, lr, alpha=0.01, beta=0.0, sigma=1.0, temperature=1.0, momenta=None)
Builds SGHMC transform.
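Algorithm from Chen et al., 2014:
\[
\begin{aligned}
θ_{t+1} &= θ_t + \epsilon σ^{-2} m_t \\
m_{t+1} &= m_t + \epsilon \nabla \log p(θ_t, \text{batch}) - \epsilon σ^{-2} \alpha m_t + N\big(0, \epsilon T (2 \alpha - \epsilon \beta T) \mathbb{I}\big)
\end{aligned}
\]
for learning rate \(\epsilon\) and temperature \(T\), with friction \(\alpha\) (alpha), gradient noise coefficient \(\beta\) (beta) and momenta scale \(σ\) (sigma).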
Targets \(p_T(θ, m) \propto \exp( (\log p(θ) - \frac{1}{2σ^2} m^Tm) / T)\) with temperature \(T\).
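The exponent splits into a θ term and an m term, so the target factorises. This short derivation makes the roles of σ and T explicit:

\[
p_T(θ, m) \propto \exp\Big(\frac{\log p(θ)}{T}\Big) \exp\Big(-\frac{m^T m}{2σ^2 T}\Big),
\qquad\text{so}\qquad
p_T(θ) \propto p(θ)^{1/T}, \quad m \sim N(0, σ^2 T\, \mathbb{I}).
\]

Under the target, θ and m are independent. At \(T = 1\) the θ-marginal is the posterior itself and the momenta have standard deviation \(σ\).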
The log posterior and temperature should be constructed in tandem so that scaling stays robust for large datasets and variable batch sizes.
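The following is a minimal pure-Python sketch of the SGHMC step for a single scalar parameter. It is not the posteriors implementation, which works on tensor trees with torch. The step applies \(θ \leftarrow θ + \epsilon σ^{-2} m\) and then \(m \leftarrow m + \epsilon \nabla \log p(θ) - \epsilon σ^{-2} \alpha m + \sqrt{\epsilon T (2\alpha - \epsilon \beta T)}\, z\) with \(z \sim N(0, 1)\), where the gradient is taken at the pre-update θ. The example assumes a standard Gaussian log posterior \(\log p(θ) = -θ^2/2\) with an exact gradient, so beta = 0. The helper names sghmc_step and sample_gaussian_target are hypothetical.

```python
import math
import random


def sghmc_step(theta, m, grad, lr, alpha=0.01, beta=0.0, sigma=1.0,
               temperature=1.0, rng=random):
    """One SGHMC step for a scalar parameter (illustrative sketch, not the library code)."""
    # Variance of the injected momentum noise: lr * T * (2 * alpha - lr * beta * T).
    noise_var = temperature * lr * (2 * alpha - temperature * lr * beta)
    if noise_var < 0:
        raise ValueError("negative noise variance; need 2 * alpha >= lr * beta * temperature")
    # The parameter step uses the momenta from *before* this update.
    new_theta = theta + lr * sigma**-2 * m
    # Momenta: gradient push, friction, then injected Gaussian noise.
    new_m = (
        m
        + lr * grad
        - lr * sigma**-2 * alpha * m
        + math.sqrt(noise_var) * rng.gauss(0.0, 1.0)
    )
    return new_theta, new_m


def sample_gaussian_target(n_steps=100_000, burn_in=1_000, lr=0.05, alpha=1.0, seed=0):
    """Run the sketch on log p(θ) = -θ²/2 and return post-burn-in θ samples."""
    rng = random.Random(seed)
    theta, m = 0.0, rng.gauss(0.0, 1.0)
    samples = []
    for t in range(n_steps):
        grad = -theta  # exact gradient of -θ²/2, so beta = 0 is appropriate
        theta, m = sghmc_step(theta, m, grad, lr, alpha=alpha, rng=rng)
        if t >= burn_in:
            samples.append(theta)
    return samples


samples = sample_gaussian_target()
mean = sum(samples) / len(samples)
var = sum((s - mean) ** 2 for s in samples) / len(samples)
print(f"mean={mean:.3f} var={var:.3f}")
```

With lr = 0.05 and alpha = 1, the sample mean and variance should land close to 0 and T = 1. Expect some Monte Carlo error and a small discretisation bias that shrinks with lr. On this toy problem, the much smaller default alpha = 0.01 damps the momenta weakly and mixes slowly.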
Parameters:
Name | Type | Description | Default |
---|---|---|---|
log_posterior | LogProbFn | Function that takes parameters and an input batch and returns the log posterior value (which can be unnormalised) together with auxiliary information, e.g. from the model call. | required |
lr | float | Learning rate. | required |
alpha | float | Friction coefficient. | 0.01 |
beta | float | Gradient noise coefficient (estimated variance). | 0.0 |
sigma | float | Standard deviation of the momenta target distribution. | 1.0 |
temperature | float | Temperature of the joint parameter + momenta distribution. | 1.0 |
momenta | TensorTree \| float \| None | Initial momenta. Can be a tree with the same structure as params, or a scalar. Defaults to random i.i.d. samples from N(0, 1). | None |
Returns:
Type | Description |
---|---|
Transform | SGHMC transform instance. |
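The returned transform bundles an initialiser and an update function with the hyperparameters frozen. Below is a minimal pure-Python sketch of that pattern using functools.partial. Transform, State, init, update and build here are hypothetical stand-ins, not the posteriors objects. The sketch also assumes a gradient function grad_log_posterior is passed in directly, rather than differentiating log_posterior.

```python
import math
import random
from functools import partial
from typing import Any, Callable, NamedTuple, Optional


class Transform(NamedTuple):
    """Hypothetical stand-in for a transform: a pair of functions (init, update)."""
    init: Callable[..., Any]
    update: Callable[..., Any]


class State(NamedTuple):
    params: float
    momenta: float


def init(params: float, momenta: Optional[float] = None) -> State:
    # Default momenta: a single draw from N(0, 1).
    if momenta is None:
        momenta = random.gauss(0.0, 1.0)
    return State(params, momenta)


def update(state: State, batch: Any, grad_log_posterior, lr, alpha=0.01, beta=0.0,
           sigma=1.0, temperature=1.0) -> State:
    # One scalar SGHMC step; the gradient is taken at the current params.
    g = grad_log_posterior(state.params, batch)
    noise_sd = math.sqrt(temperature * lr * (2 * alpha - temperature * lr * beta))
    params = state.params + lr * sigma**-2 * state.momenta
    momenta = (state.momenta + lr * g - lr * sigma**-2 * alpha * state.momenta
               + noise_sd * random.gauss(0.0, 1.0))
    return State(params, momenta)


def build(grad_log_posterior, lr, alpha=0.01, beta=0.0, sigma=1.0,
          temperature=1.0, momenta=None) -> Transform:
    # Freeze every hyperparameter so callers pass only params to init
    # and (state, batch) to update.
    return Transform(
        init=partial(init, momenta=momenta),
        update=partial(update, grad_log_posterior=grad_log_posterior, lr=lr,
                       alpha=alpha, beta=beta, sigma=sigma, temperature=temperature),
    )


transform = build(lambda p, batch: -p, lr=0.1, alpha=0.5, momenta=2.0)
state = transform.init(1.0)
state = transform.update(state, None)  # the toy gradient ignores the batch
```

Freezing hyperparameters keeps the training loop uniform: every step is just state = transform.update(state, batch).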
Source code in posteriors/sgmcmc/sghmc.py
posteriors.sgmcmc.sghmc.SGHMCState
Bases: NamedTuple
State encoding params and momenta for SGHMC.
Attributes:
Name | Type | Description |
---|---|---|
params | TensorTree | Parameters. |
momenta | TensorTree | Momenta for each parameter. |
log_posterior | tensor | Log posterior evaluation. |
aux | Any | Auxiliary information from the log_posterior call. |
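The following pure-Python sketch mirrors the four attributes above as a NamedTuple. It swaps tensor trees for a plain dict of floats. It also assumes, for illustration only, that log_posterior defaults to a NaN placeholder and aux to None before the first update. Because NamedTuples are immutable, a functional update builds a new state object rather than modifying the old one.

```python
import math
from typing import Any, Dict, NamedTuple

TensorTree = Dict[str, float]  # hypothetical stand-in: a flat dict of floats


class SGHMCState(NamedTuple):
    params: TensorTree
    momenta: TensorTree
    log_posterior: float = math.nan  # assumed placeholder until the first update
    aux: Any = None


state = SGHMCState(params={"w": 0.5}, momenta={"w": -1.0})
# _replace returns a new tuple; unchanged fields still point at the same objects.
# -0.125 is log p(θ) = -θ²/2 evaluated at θ = 0.5.
next_state = state._replace(log_posterior=-0.125)
```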
Source code in posteriors/sgmcmc/sghmc.py
posteriors.sgmcmc.sghmc.init(params, momenta=None)
Initialise momenta for SGHMC.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
params | TensorTree | Parameters for which to initialise momenta. | required |
momenta | TensorTree \| float \| None | Initial momenta. Can be a tree with the same structure as params, or a scalar. Defaults to random i.i.d. samples from N(0, 1). | None |
Returns:
Type | Description |
---|---|
SGHMCState | Initial SGHMCState containing momenta. |
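The momenta argument has three documented cases: None, a scalar, or a tree like params. The sketch below handles each in pure Python, with a dict of float lists standing in for a TensorTree. The name init_momenta is hypothetical. Broadcasting a scalar to every entry is an assumption about how a scalar is interpreted, not something confirmed here.

```python
import random
from typing import Dict, List, Union

Tree = Dict[str, List[float]]  # hypothetical stand-in for a TensorTree


def init_momenta(params: Tree, momenta: Union[Tree, float, None] = None,
                 rng=random) -> Tree:
    # None: i.i.d. N(0, 1) draws with the same structure as params.
    if momenta is None:
        return {k: [rng.gauss(0.0, 1.0) for _ in v] for k, v in params.items()}
    # Scalar (assumed behaviour): fill every entry with that value.
    if isinstance(momenta, (int, float)):
        return {k: [float(momenta)] * len(v) for k, v in params.items()}
    # Otherwise the caller supplied a tree already shaped like params.
    return momenta


params = {"w": [1.0, 2.0], "b": [0.0]}
print(init_momenta(params, 0.5))  # {'w': [0.5, 0.5], 'b': [0.5]}
```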
Source code in posteriors/sgmcmc/sghmc.py
posteriors.sgmcmc.sghmc.update(state, batch, log_posterior, lr, alpha=0.01, beta=0.0, sigma=1.0, temperature=1.0, inplace=False)
Updates parameters and momenta for SGHMC.
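Update rule from Chen et al., 2014:
\[
\begin{aligned}
θ_{t+1} &= θ_t + \epsilon σ^{-2} m_t \\
m_{t+1} &= m_t + \epsilon \nabla \log p(θ_t, \text{batch}) - \epsilon σ^{-2} \alpha m_t + N\big(0, \epsilon T (2 \alpha - \epsilon \beta T) \mathbb{I}\big)
\end{aligned}
\]
for learning rate \(\epsilon\) and temperature \(T\), with friction \(\alpha\) (alpha), gradient noise coefficient \(\beta\) (beta) and momenta scale \(σ\) (sigma).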
Parameters:
Name | Type | Description | Default |
---|---|---|---|
state | SGHMCState | SGHMCState containing params and momenta. | required |
batch | Any | Data batch to be sent to log_posterior. | required |
log_posterior | LogProbFn | Function that takes parameters and an input batch and returns the log posterior value (which can be unnormalised) together with auxiliary information, e.g. from the model call. | required |
lr | float | Learning rate. | required |
alpha | float | Friction coefficient. | 0.01 |
beta | float | Gradient noise coefficient (estimated variance). | 0.0 |
sigma | float | Standard deviation of the momenta target distribution. | 1.0 |
temperature | float | Temperature of the joint parameter + momenta distribution. | 1.0 |
inplace | bool | Whether to modify state in place. | False |
Returns:
Type | Description |
---|---|
SGHMCState | Updated state (whose tensors are pointers to the input state tensors if inplace=True). |
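The effect of inplace can be shown with Python lists standing in for tensors. With inplace=False the input state is left untouched and a fresh state is returned. With inplace=True the existing containers are overwritten and the same state object comes back. This is a minimal sketch: State and update are hypothetical names, and grad_fn is passed in directly rather than differentiating log_posterior. It also shows that the parameter step uses the momenta from before the update.

```python
import math
import random
from typing import List, NamedTuple


class State(NamedTuple):
    params: List[float]   # stand-in for a tensor tree
    momenta: List[float]


def update(state: State, grad_fn, lr, alpha=0.01, beta=0.0, sigma=1.0,
           temperature=1.0, inplace=False, rng=random) -> State:
    grads = grad_fn(state.params)  # gradients at the current (pre-update) params
    # math.sqrt raises ValueError if 2 * alpha < lr * beta * temperature.
    scale = math.sqrt(temperature * lr * (2 * alpha - temperature * lr * beta))
    # Compute both new trees from the old values before writing anything.
    new_params = [p + lr * sigma**-2 * m for p, m in zip(state.params, state.momenta)]
    new_momenta = [m + lr * g - lr * sigma**-2 * alpha * m + scale * rng.gauss(0.0, 1.0)
                   for m, g in zip(state.momenta, grads)]
    if inplace:
        # Overwrite the existing lists so other references see the new values.
        state.params[:] = new_params
        state.momenta[:] = new_momenta
        return state
    return State(new_params, new_momenta)


s = State([1.0], [2.0])
out = update(s, lambda ps: [-p for p in ps], lr=0.1, alpha=0.5, temperature=0.0)
```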
Source code in posteriors/sgmcmc/sghmc.py