BAOA

posteriors.sgmcmc.baoa.build(log_posterior, lr, alpha=0.01, sigma=1.0, temperature=1.0, momenta=None)

Builds BAOA transform.

Algorithm from Leimkuhler and Matthews, 2015 - p271.

BAOA is conjugate to BAOAB (in Leimkuhler and Matthews' terminology) but requires only a single gradient evaluation per iteration. The two are equivalent when analyzing functions of the parameter trajectory. Unlike BAOAB, BAOA is not reversible, but since we don't apply Metropolis-Hastings or momenta reversal, the algorithm remains functionally identical to BAOAB.

\[\begin{align}
m_{t+1/2} &= m_t + ε \nabla \log p(θ_t, \text{batch}), \\
θ_{t+1/2} &= θ_t + (ε / 2) σ^{-2} m_{t+1/2}, \\
m_{t+1} &= e^{-ε γ} m_{t+1/2} + N(0, ζ^2 σ^2), \\
θ_{t+1} &= θ_{t+1/2} + (ε / 2) σ^{-2} m_{t+1}
\end{align}\]

for learning rate \(ε\), temperature \(T\), transformed friction \(γ = α σ^{-2}\) and transformed noise variance \(ζ^2 = T(1 - e^{-2γε})\).
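As a brief check of the noise scale (a sketch that considers the O step in isolation, not a derivation taken from the reference): writing the O step as \(m' = e^{-εγ} m + ζ σ ξ\) with \(ξ \sim N(0, 1)\), its stationary variance \(v\) satisfies

\[\begin{align}
v &= e^{-2εγ} v + ζ^2 σ^2, \\
v &= \frac{ζ^2 σ^2}{1 - e^{-2εγ}} = T σ^2,
\end{align}\]

so this choice of \(ζ^2\) keeps the momenta at variance \(T σ^2\) regardless of the step size.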

Targets \(p_T(θ, m) \propto \exp( (\log p(θ) - \frac{1}{2σ^2} m^Tm) / T)\) with temperature \(T\).
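To make the four update steps concrete, here is a minimal scalar sketch in plain Python (standard library only, no posteriors or torch). The name `baoa_step`, the `grad_log_p` argument and the standard normal toy target are illustrative assumptions, not part of the library, and the exact gradient stands in for a minibatch gradient.

```python
import math
import random


def baoa_step(theta, m, grad_log_p, lr, alpha=0.01, sigma=1.0, temperature=1.0, rng=random):
    """One BAOA iteration for a scalar parameter (illustrative, not the library code)."""
    prec = sigma ** -2                      # σ^{-2}
    gamma = alpha * prec                    # transformed friction γ = α σ^{-2}
    zeta = math.sqrt(temperature * (1.0 - math.exp(-2.0 * gamma * lr)))  # ζ

    m = m + lr * grad_log_p(theta)          # B: full momentum kick with the gradient at θ_t
    theta = theta + 0.5 * lr * prec * m     # A: first half position step
    m = math.exp(-gamma * lr) * m + zeta * sigma * rng.gauss(0.0, 1.0)  # O: friction + noise
    theta = theta + 0.5 * lr * prec * m     # A: second half position step
    return theta, m


# Toy target (assumption): log p(θ) = -θ²/2, so the gradient of log p is -θ.
rng = random.Random(0)
theta, m = 0.0, rng.gauss(0.0, 1.0)
samples = []
for _ in range(100_000):
    theta, m = baoa_step(theta, m, lambda t: -t, lr=0.1, alpha=1.0, rng=rng)
    samples.append(theta)

mean = sum(samples) / len(samples)
variance = sum((s - mean) ** 2 for s in samples) / len(samples)
```

With these settings the sample mean of θ should sit near 0 and the sample variance near 1, the moments of the toy target at \(T = 1\).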

We recommend constructing the log posterior and temperature in tandem so that scaling stays robust for large datasets and variable batch sizes.
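One possible reading of "in tandem", stated here as an assumption rather than the library's prescription, is to average the log likelihood over the batch, divide the log prior by the dataset size N, and set the temperature to 1/N. The sketch below uses a hypothetical scalar Gaussian model in plain Python; `make_log_posterior`, `log_prior` and `log_lik` are illustrative names.

```python
def log_prior(theta):
    # Unnormalised N(0, 1) prior on a scalar parameter (toy assumption).
    return -0.5 * theta ** 2


def log_lik(theta, y):
    # Unnormalised unit-variance Gaussian likelihood for one observation.
    return -0.5 * (y - theta) ** 2


def make_log_posterior(num_data):
    """Return a per-datum scaled log posterior that yields (value, aux)."""

    def log_posterior(theta, batch):
        mean_ll = sum(log_lik(theta, y) for y in batch) / len(batch)
        value = log_prior(theta) / num_data + mean_ll
        return value, None  # no auxiliary information in this toy model

    return log_posterior


data = [0.5, 1.5, -0.2, 0.9]
num_data = len(data)
log_posterior = make_log_posterior(num_data)
temperature = 1.0 / num_data  # chosen together with the 1/N scaling above

# On the full dataset, value / temperature recovers the usual unnormalised log posterior.
theta = 0.3
scaled_value, _ = log_posterior(theta, data)
full_value = log_prior(theta) + sum(log_lik(theta, y) for y in data)
```

Under this scaling the log posterior value stays of order one as the dataset grows, while dividing by the temperature restores the full-data posterior that the sampler targets.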

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| log_posterior | LogProbFn | Function that takes parameters and input batch and returns the log posterior value (which can be unnormalised) as well as auxiliary information, e.g. from the model call. | required |
| lr | float \| Schedule | Learning rate. Scalar or schedule (callable taking step index, returning scalar). | required |
| alpha | float | Friction coefficient. | 0.01 |
| sigma | float | Standard deviation of momenta target distribution. | 1.0 |
| temperature | float \| Schedule | Temperature of the joint parameter + momenta distribution. Scalar or schedule (callable taking step index, returning scalar). | 1.0 |
| momenta | TensorTree \| float \| None | Initial momenta. Can be tree like params or scalar. Defaults to random iid samples from N(0, 1). | None |
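Both `lr` and `temperature` accept either a scalar or a schedule. A minimal sketch of that convention in plain Python follows; `resolve` and `polynomial_decay` are hypothetical names, and the decay constants are arbitrary illustrative values.

```python
def resolve(value, step):
    """Evaluate a scalar-or-schedule argument at the given step index."""
    return value(step) if callable(value) else value


def polynomial_decay(step, base=1e-2, power=0.55):
    """Example schedule: a step size that shrinks polynomially with the step index."""
    return base * (step + 1) ** -power


constant_lr = resolve(0.1, step=5)               # scalars pass through unchanged
first_lr = resolve(polynomial_decay, step=0)     # schedule evaluated at step 0
later_lr = resolve(polynomial_decay, step=1000)  # smaller step size later on
```

A schedule defined this way could be passed wherever a scalar is accepted, since the only requirement stated is a callable that maps the step index to a scalar.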

Returns:

| Type | Description |
| --- | --- |
| Transform | BAOA transform instance. |

Source code in posteriors/sgmcmc/baoa.py
def build(
    log_posterior: LogProbFn,
    lr: float | Schedule,
    alpha: float = 0.01,
    sigma: float = 1.0,
    temperature: float | Schedule = 1.0,
    momenta: TensorTree | float | None = None,
) -> Transform:
    """Builds BAOA transform.

    Algorithm from [Leimkuhler and Matthews, 2015 - p271](https://link.springer.com/book/10.1007/978-3-319-16375-8).

    BAOA is conjugate to BAOAB (in Leimkuhler and Matthews' terminology) but requires
    only a single gradient evaluation per iteration.
    The two are equivalent when analyzing functions of the parameter trajectory.
    Unlike BAOAB, BAOA is not reversible, but since we don't apply Metropolis-Hastings 
    or momenta reversal, the algorithm remains functionally identical to BAOAB.

    \\begin{align}
    m_{t+1/2} &= m_t + ε \\nabla \\log p(θ_t, \\text{batch}), \\\\
    θ_{t+1/2} &= θ_t + (ε / 2) σ^{-2} m_{t+1/2}, \\\\
    m_{t+1} &= e^{-ε γ} m_{t+1/2} + N(0, ζ^2 σ^2), \\\\
    θ_{t+1} &= θ_{t+1/2} + (ε / 2) σ^{-2} m_{t+1}
    \\end{align}

    for learning rate $ε$, temperature $T$, transformed friction $γ = α σ^{-2}$
    and transformed noise variance $ζ^2 = T(1 - e^{-2γε})$.

    Targets $p_T(θ, m) \\propto \\exp( (\\log p(θ) - \\frac{1}{2σ^2} m^Tm) / T)$
    with temperature $T$.

    We recommend [constructing the log posterior and temperature in tandem](../../log_posteriors.md)
    so that scaling stays robust for large datasets and variable batch sizes.

    Args:
        log_posterior: Function that takes parameters and input batch and
            returns the log posterior value (which can be unnormalised)
            as well as auxiliary information, e.g. from the model call.
        lr: Learning rate.
            Scalar or schedule (callable taking step index, returning scalar).
        alpha: Friction coefficient.
        sigma: Standard deviation of momenta target distribution.
        temperature: Temperature of the joint parameter + momenta distribution.
            Scalar or schedule (callable taking step index, returning scalar).
        momenta: Initial momenta. Can be tree like params or scalar.
            Defaults to random iid samples from N(0, 1).

    Returns:
        BAOA transform instance.
    """
    init_fn = partial(init, momenta=momenta)
    update_fn = partial(
        update,
        log_posterior=log_posterior,
        lr=lr,
        alpha=alpha,
        sigma=sigma,
        temperature=temperature,
    )
    return Transform(init_fn, update_fn)
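The body of `build` only binds hyperparameters with `partial` and packs the two resulting callables together. The following plain-Python sketch mirrors that pattern with hypothetical stand-ins (`ToyTransform`, `toy_init`, `toy_update`, `toy_build`); the field names `init` and `update` are an assumption for illustration and the toy state is a plain dict, not a `BAOAState`.

```python
from functools import partial
from typing import Callable, NamedTuple


class ToyTransform(NamedTuple):
    """Hypothetical stand-in for the library's Transform: a pair of callables."""

    init: Callable
    update: Callable


def toy_init(params, momenta=None):
    # Build an initial state; a real implementation would sample momenta here.
    return {"params": params, "momenta": 0.0 if momenta is None else momenta, "step": 0}


def toy_update(state, batch, lr, alpha=0.01):
    # Return a new state plus auxiliary information, leaving the input untouched.
    new_state = dict(state, step=state["step"] + 1)
    aux = {"lr": lr, "alpha": alpha, "batch_size": len(batch)}
    return new_state, aux


def toy_build(lr, alpha=0.01, momenta=None):
    # Bind hyperparameters once so callers only pass params, state and batch.
    init_fn = partial(toy_init, momenta=momenta)
    update_fn = partial(toy_update, lr=lr, alpha=alpha)
    return ToyTransform(init_fn, update_fn)


transform = toy_build(lr=1e-2, alpha=0.5)
state = transform.init(1.0)
state, aux = transform.update(state, [1, 2, 3])
```

The benefit of this design is that every sampler exposes the same two-function interface, with algorithm-specific settings fixed at build time.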

posteriors.sgmcmc.baoa.BAOAState

Bases: TensorClass['frozen']

State encoding params and momenta for BAOA.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| params | TensorTree | Parameters. |
| momenta | TensorTree | Momenta for each parameter. |
| log_posterior | Tensor | Log posterior evaluation. |
| step | Tensor | Current step count. |

Source code in posteriors/sgmcmc/baoa.py
class BAOAState(TensorClass["frozen"]):
    """State encoding params and momenta for BAOA.

    Attributes:
        params: Parameters.
        momenta: Momenta for each parameter.
        log_posterior: Log posterior evaluation.
        step: Current step count.
    """

    params: TensorTree
    momenta: TensorTree
    log_posterior: torch.Tensor = torch.tensor(torch.nan)
    step: torch.Tensor = torch.tensor(0)

posteriors.sgmcmc.baoa.init(params, momenta=None)

Initialise momenta for BAOA.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| params | TensorTree | Parameters for which to initialise. | required |
| momenta | TensorTree \| float \| None | Initial momenta. Can be tree like params or scalar. Defaults to random iid samples from N(0, 1). | None |

Returns:

| Type | Description |
| --- | --- |
| BAOAState | Initial BAOAState containing momenta. |

Source code in posteriors/sgmcmc/baoa.py
def init(params: TensorTree, momenta: TensorTree | float | None = None) -> BAOAState:
    """Initialise momenta for BAOA.

    Args:
        params: Parameters for which to initialise.
        momenta: Initial momenta. Can be tree like params or scalar.
            Defaults to random iid samples from N(0, 1).

    Returns:
        Initial BAOAState containing momenta.
    """
    if momenta is None:
        momenta = tree_map(
            lambda x: torch.randn_like(x, requires_grad=x.requires_grad),
            params,
        )
    elif is_scalar(momenta):
        momenta = tree_map(
            lambda x: torch.full_like(x, momenta, requires_grad=x.requires_grad),
            params,
        )

    return BAOAState(params, momenta)
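The three cases handled by `init` (no momenta, a scalar, or a full tree) can be sketched without torch by using nested dicts and lists of floats as the tree. The `tree_map` below is a hypothetical minimal helper written for this example, not the one imported by the library, and `init_momenta` is an illustrative name.

```python
import random


def tree_map(fn, tree):
    """Apply fn to every leaf of a nested dict/list/tuple structure (toy helper)."""
    if isinstance(tree, dict):
        return {key: tree_map(fn, value) for key, value in tree.items()}
    if isinstance(tree, (list, tuple)):
        return type(tree)(tree_map(fn, value) for value in tree)
    return fn(tree)


def init_momenta(params, momenta=None, rng=random):
    """Mirror the three initialisation cases on a tree of Python floats."""
    if momenta is None:
        # Default: one independent standard normal draw per parameter leaf.
        return tree_map(lambda _: rng.gauss(0.0, 1.0), params)
    if isinstance(momenta, (int, float)):
        # Scalar: broadcast the same value to every leaf.
        return tree_map(lambda _: float(momenta), params)
    # Otherwise assume the caller supplied a tree shaped like params.
    return momenta


params = {"w": [0.1, 0.2], "b": 0.0}
zero_momenta = init_momenta(params, 0.0)
random_momenta = init_momenta(params, rng=random.Random(0))
custom_momenta = init_momenta(params, {"w": [1.0, -1.0], "b": 0.5})
```

Starting from zero momenta gives a deterministic first kick from the gradient, whereas the random default begins with momenta drawn from a standard normal.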

posteriors.sgmcmc.baoa.update(state, batch, log_posterior, lr, alpha=0.01, sigma=1.0, temperature=1.0, inplace=False)

Updates parameters and momenta for BAOA.

Algorithm from Leimkuhler and Matthews, 2015 - p271.

See build for more details.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| state | BAOAState | BAOAState containing params and momenta. | required |
| batch | Any | Data batch to be sent to log_posterior. | required |
| log_posterior | LogProbFn | Function that takes parameters and input batch and returns the log posterior value (which can be unnormalised) as well as auxiliary information, e.g. from the model call. | required |
| lr | float \| Schedule | Learning rate. Scalar or schedule (callable taking step index, returning scalar). | required |
| alpha | float | Friction coefficient. | 0.01 |
| sigma | float | Standard deviation of momenta target distribution. | 1.0 |
| temperature | float \| Schedule | Temperature of the joint parameter + momenta distribution. Scalar or schedule (callable taking step index, returning scalar). | 1.0 |
| inplace | bool | Whether to modify state in place. | False |

Returns:

| Type | Description |
| --- | --- |
| tuple[BAOAState, TensorTree] | Updated state (whose tensors are pointers to the input state tensors if inplace=True) and auxiliary information. |

Source code in posteriors/sgmcmc/baoa.py
def update(
    state: BAOAState,
    batch: Any,
    log_posterior: LogProbFn,
    lr: float | Schedule,
    alpha: float = 0.01,
    sigma: float = 1.0,
    temperature: float | Schedule = 1.0,
    inplace: bool = False,
) -> tuple[BAOAState, TensorTree]:
    """Updates parameters and momenta for BAOA.

    Algorithm from [Leimkuhler and Matthews, 2015 - p271](https://link.springer.com/book/10.1007/978-3-319-16375-8).

    See [build](baoa.md#posteriors.sgmcmc.baoa.build) for more details.

    Args:
        state: BAOAState containing params and momenta.
        batch: Data batch to be sent to log_posterior.
        log_posterior: Function that takes parameters and input batch and
            returns the log posterior value (which can be unnormalised)
            as well as auxiliary information, e.g. from the model call.
        lr: Learning rate.
            Scalar or schedule (callable taking step index, returning scalar).
        alpha: Friction coefficient.
        sigma: Standard deviation of momenta target distribution.
        temperature: Temperature of the joint parameter + momenta distribution.
            Scalar or schedule (callable taking step index, returning scalar).
        inplace: Whether to modify state in place.

    Returns:
        Updated state (whose tensors are pointers to the input state tensors
        if inplace=True) and auxiliary information.
    """
    with torch.no_grad(), CatchAuxError():
        grads, (log_post, aux) = grad_and_value(log_posterior, has_aux=True)(
            state.params, batch
        )

    lr = lr(state.step) if callable(lr) else lr
    temperature = temperature(state.step) if callable(temperature) else temperature
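    prec = sigma**-2  # precision σ^{-2} of the momenta distribution
    gamma = torch.tensor(alpha * prec)  # transformed friction γ = α σ^{-2}
    # Note: despite its name, zeta2 holds ζ (the square root of ζ^2),
    # so that zeta2 * sigma is the standard deviation of the O-step noise.
    zeta2 = (temperature * (1 - torch.exp(-2 * gamma * lr))) ** 0.5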

    def BB_step(m, g):
        return m + lr * g

    def A_step(p, m):
        return p + (lr / 2) * prec * m

    def O_step(m):
        return torch.exp(-gamma * lr) * m + zeta2 * sigma * torch.randn_like(m)

    momenta = flexi_tree_map(BB_step, state.momenta, grads, inplace=inplace)
    params = flexi_tree_map(A_step, state.params, momenta, inplace=inplace)
    momenta = flexi_tree_map(O_step, momenta, inplace=inplace)
    params = flexi_tree_map(A_step, params, momenta, inplace=inplace)

    if inplace:
        tree_insert_(state.log_posterior, log_post.detach())
        tree_insert_(state.step, state.step + 1)
        return state, aux
    return BAOAState(params, momenta, log_post.detach(), state.step + 1), aux