SGNHT

posteriors.sgmcmc.sgnht.build(log_posterior, lr, alpha=0.01, beta=0.0, sigma=1.0, temperature=1.0, momenta=None, xi=None)

Builds SGNHT transform.

Algorithm from Ding et al, 2014:

\[\begin{align} θ_{t+1} &= θ_t + ε σ^{-2} m_t \\ m_{t+1} &= m_t + ε \nabla \log p(θ_t, \text{batch}) - ε σ^{-2} ξ_t m_t + N(0, ε T (2 α - ε β T) \mathbb{I})\\ ξ_{t+1} &= ξ_t + ε (σ^{-2} d^{-1} m_t^T m_t - T) \end{align}\]

for learning rate \(\epsilon\), temperature \(T\) and parameter dimension \(d\).

Targets \(p_T(θ, m, ξ) \propto \exp( (\log p(θ) - \frac{1}{2σ^2} m^Tm - \frac{d}{2}(ξ - α)^2) / T)\).
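
The three update equations can be sketched on plain Python lists as follows. This is a minimal illustration, not the library implementation: `sgnht_step` and `grad_log_standard_normal` are hypothetical names, parameters are flat lists of floats rather than tensor trees, and the target is assumed to be a standard normal so that the gradient of the log posterior is simply \(-θ\).

```python
import math
import random


def sgnht_step(theta, m, xi, grad_log_post, lr, alpha=0.01, beta=0.0,
               sigma=1.0, temperature=1.0, rng=random):
    """One SGNHT step on flat lists (hypothetical helper, not the library API)."""
    prec = sigma ** -2
    d = len(theta)
    grad = grad_log_post(theta)
    # Standard deviation of the injected noise: sqrt(eps * T * (2*alpha - eps*beta*T)).
    noise_sd = math.sqrt(temperature * lr * (2 * alpha - temperature * lr * beta))
    # Every right-hand side uses the time-t values of theta, m and xi.
    theta_new = [t + lr * prec * mi for t, mi in zip(theta, m)]
    m_new = [
        mi + lr * g - lr * prec * xi * mi + noise_sd * rng.gauss(0.0, 1.0)
        for mi, g in zip(m, grad)
    ]
    xi_new = xi + lr * (prec * sum(mi * mi for mi in m) / d - temperature)
    return theta_new, m_new, xi_new


def grad_log_standard_normal(theta):
    """Gradient of log N(theta | 0, I), the assumed toy target."""
    return [-t for t in theta]


random.seed(0)
theta, m, xi = [1.0, -2.0], [0.5, 0.5], 0.01
for _ in range(1000):
    theta, m, xi = sgnht_step(theta, m, xi, grad_log_standard_normal, lr=0.01)
print(theta, xi)
```

The thermostat \(ξ\) grows when the average squared momentum (scaled by \(σ^{-2}\)) exceeds \(T\), which increases friction on the next step, and shrinks otherwise.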

The log posterior and temperature are recommended to be constructed in tandem to ensure robust scaling for a large amount of data and variable batch size.
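
One possible reading of "constructed in tandem" is sketched below. It assumes the log posterior is scaled per data point (batch-mean log likelihood plus log prior divided by the dataset size \(N\)) and paired with temperature \(1/N\), so that dividing by the temperature recovers an estimate of the full-data log posterior. The model (Gaussian likelihood with unknown mean), the helper names `log_normal` and `make_log_posterior`, and the scaling itself are assumptions for illustration.

```python
import math


def log_normal(x, mean, sd):
    """Log density of N(x | mean, sd^2)."""
    return -0.5 * ((x - mean) / sd) ** 2 - math.log(sd) - 0.5 * math.log(2 * math.pi)


def make_log_posterior(num_data, prior_sd=1.0):
    """Return a per-datapoint-scaled log posterior and its matching temperature."""

    def log_posterior(theta, batch):
        # The batch mean keeps the scale independent of the batch size.
        mean_log_lik = sum(log_normal(x, theta, 1.0) for x in batch) / len(batch)
        log_prior = log_normal(theta, 0.0, prior_sd)
        value = mean_log_lik + log_prior / num_data
        aux = None  # placeholder for auxiliary information
        return value, aux

    temperature = 1.0 / num_data
    return log_posterior, temperature


data = [0.1, -0.3, 0.7, 1.2]
log_posterior, temperature = make_log_posterior(num_data=len(data))
value, _ = log_posterior(0.5, data)
# With the full dataset as the batch, value / temperature is the full log posterior.
print(value / temperature)
```

Under this scaling the values returned by `log_posterior` stay of order one regardless of \(N\) or the batch size, and the temperature carries the dependence on \(N\).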

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `log_posterior` | `LogProbFn` | Function that takes parameters and input batch and returns the log posterior value (which can be unnormalised) as well as auxiliary information, e.g. from the model call. | required |
| `lr` | `float` | Learning rate. | required |
| `alpha` | `float` | Friction coefficient. | `0.01` |
| `beta` | `float` | Gradient noise coefficient (estimated variance). | `0.0` |
| `sigma` | `float` | Standard deviation of momenta target distribution. | `1.0` |
| `temperature` | `float` | Temperature of the joint parameter + momenta distribution. | `1.0` |
| `momenta` | `TensorTree \| float \| None` | Initial momenta. Can be tree like params or scalar. Defaults to random iid samples from N(0, 1). | `None` |
| `xi` | `float \| None` | Initial value for scalar thermostat ξ. Defaults to `alpha`. | `None` |
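
A `None`-aware way to resolve the default for `xi` can be sketched as below. The helper name `resolve_initial_xi` is hypothetical. The point of the sketch is that an explicit `0.0` is a legitimate thermostat value and should be distinguished from "not provided", which a plain truthiness check would not do.

```python
def resolve_initial_xi(xi, alpha):
    """Return xi if it was provided, otherwise fall back to alpha."""
    return alpha if xi is None else xi


print(resolve_initial_xi(None, 0.01))  # falls back to alpha
print(resolve_initial_xi(0.0, 0.01))   # keeps the explicit zero
```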

Returns:

| Type | Description |
| --- | --- |
| `Transform` | SGNHT transform instance. |

Source code in posteriors/sgmcmc/sgnht.py
def build(
    log_posterior: LogProbFn,
    lr: float,
    alpha: float = 0.01,
    beta: float = 0.0,
    sigma: float = 1.0,
    temperature: float = 1.0,
    momenta: TensorTree | float | None = None,
    xi: float | None = None,
) -> Transform:
    """Builds SGNHT transform.

    Algorithm from [Ding et al, 2014](https://proceedings.neurips.cc/paper/2014/file/21fe5b8ba755eeaece7a450849876228-Paper.pdf):

    \\begin{align}
    θ_{t+1} &= θ_t + ε σ^{-2} m_t \\\\
    m_{t+1} &= m_t + ε \\nabla \\log p(θ_t, \\text{batch}) - ε σ^{-2} ξ_t m_t
    + N(0, ε T (2 α - ε β T) \\mathbb{I})\\\\
    ξ_{t+1} &= ξ_t + ε (σ^{-2} d^{-1} m_t^T m_t - T)
    \\end{align}

    for learning rate $\\epsilon$, temperature $T$ and parameter dimension $d$.

    Targets $p_T(θ, m, ξ) \\propto \\exp( (\\log p(θ) - \\frac{1}{2σ^2} m^Tm - \\frac{d}{2}(ξ - α)^2) / T)$.

    The log posterior and temperature are recommended to be [constructed in tandem](../../log_posteriors.md)
    to ensure robust scaling for a large amount of data and variable batch size.

    Args:
        log_posterior: Function that takes parameters and input batch and
            returns the log posterior value (which can be unnormalised)
            as well as auxiliary information, e.g. from the model call.
        lr: Learning rate.
        alpha: Friction coefficient.
        beta: Gradient noise coefficient (estimated variance).
        sigma: Standard deviation of momenta target distribution.
        temperature: Temperature of the joint parameter + momenta distribution.
        momenta: Initial momenta. Can be tree like params or scalar.
            Defaults to random iid samples from N(0, 1).
        xi: Initial value for scalar thermostat ξ. Defaults to `alpha`.

    Returns:
        SGNHT transform instance.
    """
    # Fall back to alpha only when xi is not provided, so an explicit 0.0 is kept.
    init_fn = partial(init, momenta=momenta, xi=alpha if xi is None else xi)
    update_fn = partial(
        update,
        log_posterior=log_posterior,
        lr=lr,
        alpha=alpha,
        beta=beta,
        sigma=sigma,
        temperature=temperature,
    )
    return Transform(init_fn, update_fn)

posteriors.sgmcmc.sgnht.SGNHTState

Bases: NamedTuple

State encoding params, momenta and thermostat xi for SGNHT.

Attributes:

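| Name | Type | Description |
| --- | --- | --- |
| `params` | `TensorTree` | Parameters. |
| `momenta` | `TensorTree` | Momenta for each parameter. |
| `xi` | `Tensor` | Scalar thermostat ξ. |
| `log_posterior` | `Tensor` | Log posterior evaluation. |
| `aux` | `Any` | Auxiliary information from the log_posterior call. |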

Source code in posteriors/sgmcmc/sgnht.py
class SGNHTState(NamedTuple):
    """State encoding params, momenta and thermostat xi for SGNHT.

    Attributes:
        params: Parameters.
        momenta: Momenta for each parameter.
        xi: Scalar thermostat ξ.
        log_posterior: Log posterior evaluation.
        aux: Auxiliary information from the log_posterior call.
    """

    params: TensorTree
    momenta: TensorTree
    xi: torch.Tensor = torch.tensor([])
    log_posterior: torch.Tensor = torch.tensor([])
    aux: Any = None

posteriors.sgmcmc.sgnht.init(params, momenta=None, xi=0.01)

Initialise momenta for SGNHT.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `params` | `TensorTree` | Parameters for which to initialise. | required |
| `momenta` | `TensorTree \| float \| None` | Initial momenta. Can be tree like params or scalar. Defaults to random iid samples from N(0, 1). | `None` |
| `xi` | `float \| Tensor` | Initial value for scalar thermostat ξ. | `0.01` |

Returns:

| Type | Description |
| --- | --- |
| `SGNHTState` | Initial SGNHTState containing momenta. |
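
The three ways of specifying initial momenta (omitted, scalar, or a full tree) can be sketched without tensors as below. Leaves here are plain floats in nested dicts and lists, and `tree_map` and `init_momenta` are hypothetical stand-ins written for this illustration, not the library's own utilities.

```python
import random


def tree_map(fn, tree):
    """Apply fn to every leaf of a nested dict/list tree (illustrative stand-in)."""
    if isinstance(tree, dict):
        return {key: tree_map(fn, value) for key, value in tree.items()}
    if isinstance(tree, list):
        return [tree_map(fn, value) for value in tree]
    return fn(tree)


def init_momenta(params, momenta=None, rng=random):
    """Resolve initial momenta with the same tree structure as params."""
    if momenta is None:
        # One independent N(0, 1) draw per parameter leaf.
        return tree_map(lambda _: rng.gauss(0.0, 1.0), params)
    if isinstance(momenta, (int, float)):
        # A scalar is broadcast to every leaf.
        return tree_map(lambda _: float(momenta), params)
    # Otherwise assume the caller passed a tree shaped like params.
    return momenta


params = {"w": [1.0, 2.0], "b": 0.5}
print(init_momenta(params))       # random draws
print(init_momenta(params, 0.0))  # all zeros
```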

Source code in posteriors/sgmcmc/sgnht.py
def init(
    params: TensorTree,
    momenta: TensorTree | float | None = None,
    xi: float | torch.Tensor = 0.01,
) -> SGNHTState:
    """Initialise momenta for SGNHT.

    Args:
        params: Parameters for which to initialise.
        momenta: Initial momenta. Can be tree like params or scalar.
            Defaults to random iid samples from N(0, 1).
        xi: Initial value for scalar thermostat ξ.

    Returns:
        Initial SGNHTState containing momenta.
    """
    if momenta is None:
        momenta = tree_map(
            lambda x: torch.randn_like(x, requires_grad=x.requires_grad),
            params,
        )
    elif is_scalar(momenta):
        momenta = tree_map(
            lambda x: torch.full_like(x, momenta, requires_grad=x.requires_grad),
            params,
        )

    return SGNHTState(params, momenta, torch.tensor(xi))

posteriors.sgmcmc.sgnht.update(state, batch, log_posterior, lr, alpha=0.01, beta=0.0, sigma=1.0, temperature=1.0, inplace=False)

Updates parameters, momenta and xi for SGNHT.

Update rule from Ding et al, 2014:

\[\begin{align} θ_{t+1} &= θ_t + ε σ^{-2} m_t \\ m_{t+1} &= m_t + ε \nabla \log p(θ_t, \text{batch}) - ε σ^{-2} ξ_t m_t + N(0, ε T (2 α - ε β T) \mathbb{I})\\ ξ_{t+1} &= ξ_t + ε (σ^{-2} d^{-1} m_t^T m_t - T) \end{align}\]

for learning rate \(\epsilon\), temperature \(T\) and parameter dimension \(d\).
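
The noise term in the momenta update has variance \(ε T (2 α - ε β T)\), which is only non-negative when \(ε β T \le 2 α\). A small sketch of that arithmetic follows. The helper name `momenta_noise_sd` and the explicit guard are additions for illustration and are not claimed to match how the library handles an invalid combination.

```python
import math


def momenta_noise_sd(lr, alpha=0.01, beta=0.0, temperature=1.0):
    """Standard deviation of the Gaussian noise added to each momentum."""
    variance = temperature * lr * (2 * alpha - temperature * lr * beta)
    if variance < 0:
        # Guard added for this sketch: the variance must be non-negative.
        raise ValueError("need lr * beta * temperature <= 2 * alpha")
    return math.sqrt(variance)


# With the default alpha=0.01 and beta=0.0, the variance is lr * 2 * alpha.
print(momenta_noise_sd(lr=0.01))
```

With `beta=0.0` any positive learning rate is valid. With a positive gradient noise estimate, the learning rate is bounded above by `2 * alpha / (beta * temperature)`.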

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `state` | `SGNHTState` | SGNHTState containing params, momenta and xi. | required |
| `batch` | `Any` | Data batch to be sent to log_posterior. | required |
| `log_posterior` | `LogProbFn` | Function that takes parameters and input batch and returns the log posterior value (which can be unnormalised) as well as auxiliary information, e.g. from the model call. | required |
| `lr` | `float` | Learning rate. | required |
| `alpha` | `float` | Friction coefficient. | `0.01` |
| `beta` | `float` | Gradient noise coefficient (estimated variance). | `0.0` |
| `sigma` | `float` | Standard deviation of momenta target distribution. | `1.0` |
| `temperature` | `float` | Temperature of the joint parameter + momenta distribution. | `1.0` |
| `inplace` | `bool` | Whether to modify state in place. | `False` |

Returns:

| Type | Description |
| --- | --- |
| `SGNHTState` | Updated SGNHTState (whose tensors are pointers to the input state tensors if inplace=True). |
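
The difference between the two return modes can be illustrated with a toy `NamedTuple` whose field is a mutable list. `ToyState` and `toy_update` are hypothetical and only mimic the pattern: in-place mode mutates the existing containers and returns a state that still points at them, while the default mode builds fresh containers and leaves the input untouched.

```python
from typing import Any, NamedTuple


class ToyState(NamedTuple):
    params: list
    aux: Any = None


def toy_update(state, step, inplace=False):
    """Shift every parameter by step, either in place or into a new state."""
    if inplace:
        for i in range(len(state.params)):
            state.params[i] += step
        # _replace builds a new tuple but reuses the same params list object.
        return state._replace(aux="updated")
    return ToyState([p + step for p in state.params], "updated")


state = ToyState([1.0, 2.0])
fresh = toy_update(state, 0.5)
print(state.params, fresh.params)

shared = toy_update(state, 0.5, inplace=True)
print(shared.params is state.params)
```

Because a `NamedTuple` is immutable, non-container fields such as `aux` only change on the returned state, so the return value should be used in both modes.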

Source code in posteriors/sgmcmc/sgnht.py
def update(
    state: SGNHTState,
    batch: Any,
    log_posterior: LogProbFn,
    lr: float,
    alpha: float = 0.01,
    beta: float = 0.0,
    sigma: float = 1.0,
    temperature: float = 1.0,
    inplace: bool = False,
) -> SGNHTState:
    """Updates parameters, momenta and xi for SGNHT.

    Update rule from [Ding et al, 2014](https://proceedings.neurips.cc/paper/2014/file/21fe5b8ba755eeaece7a450849876228-Paper.pdf):

    \\begin{align}
    θ_{t+1} &= θ_t + ε σ^{-2} m_t \\\\
    m_{t+1} &= m_t + ε \\nabla \\log p(θ_t, \\text{batch}) - ε σ^{-2} ξ_t m_t
    + N(0, ε T (2 α - ε β T) \\mathbb{I})\\\\
    ξ_{t+1} &= ξ_t + ε (σ^{-2} d^{-1} m_t^T m_t - T)
    \\end{align}

    for learning rate $\\epsilon$, temperature $T$ and parameter dimension $d$.

    Args:
        state: SGNHTState containing params, momenta and xi.
        batch: Data batch to be sent to log_posterior.
        log_posterior: Function that takes parameters and input batch and
            returns the log posterior value (which can be unnormalised)
            as well as auxiliary information, e.g. from the model call.
        lr: Learning rate.
        alpha: Friction coefficient.
        beta: Gradient noise coefficient (estimated variance).
        sigma: Standard deviation of momenta target distribution.
        temperature: Temperature of the joint parameter + momenta distribution.
        inplace: Whether to modify state in place.

    Returns:
        Updated SGNHTState
        (whose tensors are pointers to the input state tensors if inplace=True).
    """
    with torch.no_grad(), CatchAuxError():
        grads, (log_post, aux) = grad_and_value(log_posterior, has_aux=True)(
            state.params, batch
        )

    prec = sigma**-2

    def transform_params(p, m):
        return p + lr * prec * m

    def transform_momenta(m, g):
        return (
            m
            + lr * g
            - lr * prec * state.xi * m
            + (temperature * lr * (2 * alpha - temperature * lr * beta)) ** 0.5
            * torch.randn_like(m)
        )

    m_flat, _ = tree_ravel(state.momenta)
    xi_new = state.xi + lr * (prec * torch.mean(m_flat**2) - temperature)

    params = flexi_tree_map(
        transform_params, state.params, state.momenta, inplace=inplace
    )
    momenta = flexi_tree_map(transform_momenta, state.momenta, grads, inplace=inplace)

    if inplace:
        tree_insert_(state.xi, xi_new)
        tree_insert_(state.log_posterior, log_post.detach())
        return state._replace(aux=aux)
    return SGNHTState(params, momenta, xi_new, log_post.detach(), aux)