
Optim

posteriors.optim.build(loss_fn, optimizer, **kwargs)

Builds an optimizer transform from torch.optim

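```python
import torch
from posteriors.optim import build

# loss_fn, params and dataloader are assumed to be defined elsewhere
transform = build(loss_fn, torch.optim.Adam, lr=0.1)
state = transform.init(params)

for batch in dataloader:
    state, aux = transform.update(state, batch)
```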

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `loss_fn` | `LogProbFn` | Function that takes the parameters and a batch and returns the loss, of the form `loss, aux = fn(params, batch)`. | required |
| `optimizer` | `Type[Optimizer]` | Optimizer class from `torch.optim`. | required |
| `**kwargs` | `Any` | Keyword arguments to pass to the optimizer class. | `{}` |

Returns:

| Type | Description |
|---|---|
| `Transform` | `torch.optim` transform instance. |
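As the table states, `loss_fn` returns a pair `(loss, aux)`. Because `torch.optim` optimizers minimize by default, a log-density has to be negated before it is used as the loss. The sketch below assumes a hypothetical linear model with parameters in a dict (`"w"`, `"b"`) and a batch of `(x, y)` tensors; `neg_log_lik_loss` is an illustrative name, not part of `posteriors`.

```python
import torch


def neg_log_lik_loss(params, batch):
    """Hypothetical loss_fn of the form ``loss, aux = fn(params, batch)``."""
    x, y = batch
    preds = x @ params["w"] + params["b"]
    # Unit-variance Gaussian log-likelihood summed over the batch (constant dropped)
    log_lik = -0.5 * ((y - preds) ** 2).sum()
    # torch.optim minimizes, so the loss is the negative log-likelihood;
    # aux returns the (detached) predictions as extra information
    return -log_lik, preds.detach()


params = {
    "w": torch.zeros(3, requires_grad=True),
    "b": torch.zeros((), requires_grad=True),
}
x = torch.ones(4, 3)
y = torch.full((4,), 2.0)
loss, aux = neg_log_lik_loss(params, (x, y))
```

Here every prediction is 0 and every target is 2, so the loss is 0.5 × 4 × 2² = 8. The `aux` slot can carry anything (metrics, predictions, or `None`); it is passed back unchanged from each update.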

Source code in posteriors/optim.py
def build(
    loss_fn: LogProbFn,
    optimizer: Type[torch.optim.Optimizer],
    **kwargs: Any,
) -> Transform:
    """Builds an optimizer transform from [torch.optim](https://pytorch.org/docs/stable/optim.html)

    ```
    transform = build(loss_fn, torch.optim.Adam, lr=0.1)
    state = transform.init(params)

    for batch in dataloader:
        state, aux = transform.update(state, batch)
    ```

    Args:
        loss_fn: Function that takes the parameters and a batch and returns the loss,
            of the form `loss, aux = fn(params, batch)`.
        optimizer: Optimizer class from torch.optim.
        **kwargs: Keyword arguments to pass to the optimizer class.

    Returns:
        `torch.optim` transform instance.
    """
    init_fn = partial(init, optimizer_cls=optimizer, **kwargs)
    update_fn = partial(update, loss_fn=loss_fn)
    return Transform(init_fn, update_fn)
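`build` only pre-binds arguments: `init` receives the optimizer class and its keyword arguments, and `update` receives `loss_fn`, so the returned transform is called as `transform.init(params)` and `transform.update(state, batch)`. The plain-PyTorch sketch below mirrors that wiring with hypothetical stand-ins (`SimpleTransform`, `simple_init`, `simple_update`, `simple_build`) and a dict-based state instead of `OptimState`; it is an illustration of the pattern, not the `posteriors` implementation. It uses `torch.optim.SGD` rather than Adam so that convergence on the toy quadratic is easy to predict.

```python
from functools import partial
from typing import Callable, NamedTuple

import torch


class SimpleTransform(NamedTuple):
    """Hypothetical stand-in for Transform: an (init, update) pair."""

    init: Callable
    update: Callable


def simple_init(params, optimizer_cls, **kwargs):
    # Create the torch.optim optimizer over the parameter tensors
    optimizer = optimizer_cls(list(params.values()), **kwargs)
    return {"params": params, "optimizer": optimizer, "loss": torch.tensor([])}


def simple_update(state, batch, loss_fn):
    state["optimizer"].zero_grad()
    loss, aux = loss_fn(state["params"], batch)
    loss.backward()
    state["optimizer"].step()  # modifies the tensors in state["params"] in place
    state["loss"] = loss.detach()
    return state, aux


def simple_build(loss_fn, optimizer, **kwargs):
    # Same pattern as build(): bind optimizer settings into init, loss_fn into update
    init_fn = partial(simple_init, optimizer_cls=optimizer, **kwargs)
    update_fn = partial(simple_update, loss_fn=loss_fn)
    return SimpleTransform(init_fn, update_fn)


def quadratic_loss(params, batch):
    # Minimized at params["w"] == batch; no auxiliary output
    return ((params["w"] - batch) ** 2).sum(), None


transform = simple_build(quadratic_loss, torch.optim.SGD, lr=0.1)
state = transform.init({"w": torch.zeros(2, requires_grad=True)})
target = torch.tensor([1.0, -1.0])
for _ in range(200):
    state, aux = transform.update(state, target)
```

With `lr=0.1` each SGD step maps `w` to `0.8 * w + 0.2 * target`, so the error shrinks by a factor of 0.8 per step and `w` ends up at the target. Binding arguments up front keeps the training loop itself independent of which optimizer was chosen.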

posteriors.optim.OptimState

Bases: TensorClass['frozen']

State of an optimizer from torch.optim.

Attributes:

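| Name | Type | Description |
|---|---|---|
| `params` | `TensorTree` | Parameters to be optimized. |
| `optimizer` | `Optimizer` | `torch.optim` optimizer instance. |
| `loss` | `Tensor` | Loss value from the most recent update (an empty tensor before the first update). |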

Source code in posteriors/optim.py
class OptimState(TensorClass["frozen"]):
    """State of an optimizer from [torch.optim](https://pytorch.org/docs/stable/optim.html).

    Attributes:
        params: Parameters to be optimized.
        optimizer: torch.optim optimizer instance.
        loss: Loss value.
    """

    params: TensorTree
    optimizer: torch.optim.Optimizer
    loss: torch.Tensor = torch.tensor([])
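`OptimState` is a frozen container, so its fields are not reassigned between steps. Instead, the optimizer is constructed from the same tensor objects that `state.params` holds, and `optimizer.step()` changes them in place; `update` likewise writes the loss with the in-place `tree_insert_`. The sketch below illustrates this with a hypothetical frozen `dataclass`, `SketchOptimState`, standing in for the `TensorClass`-based `OptimState`; it does not use `posteriors`.

```python
from dataclasses import FrozenInstanceError, dataclass, field

import torch


@dataclass(frozen=True)
class SketchOptimState:
    """Hypothetical frozen stand-in for OptimState (not the posteriors class)."""

    params: dict
    optimizer: torch.optim.Optimizer
    loss: torch.Tensor = field(default_factory=lambda: torch.tensor([]))


w = torch.ones(2, requires_grad=True)
state = SketchOptimState({"w": w}, torch.optim.SGD([w], lr=0.5))

# The optimizer holds the very same tensor object that state.params exposes
shares_tensor = state.optimizer.param_groups[0]["params"][0] is state.params["w"]

# A step changes that tensor in place: grad of sum(w**2) is 2*w = [2, 2],
# so w becomes 1 - 0.5 * 2 = 0
(state.params["w"] ** 2).sum().backward()
state.optimizer.step()

# Reassigning a field of the frozen container is rejected
try:
    state.loss = torch.tensor(1.0)
    reassigned = True
except FrozenInstanceError:
    reassigned = False
```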

posteriors.optim.init(params, optimizer_cls, *args, **kwargs)

Initialise a torch.optim optimizer state.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `params` | `TensorTree` | Parameters to be optimized. | required |
| `optimizer_cls` | `Type[Optimizer]` | Optimizer class from `torch.optim`. | required |
| `*args` | `Any` | Positional arguments to pass to the optimizer class. | `()` |
| `**kwargs` | `Any` | Keyword arguments to pass to the optimizer class. | `{}` |

Returns:

| Type | Description |
|---|---|
| `OptimState` | Initial `OptimState`. |

Source code in posteriors/optim.py
def init(
    params: TensorTree,
    optimizer_cls: Type[torch.optim.Optimizer],
    *args: Any,
    **kwargs: Any,
) -> OptimState:
    """Initialise a [torch.optim](https://pytorch.org/docs/stable/optim.html) optimizer
    state.

    Args:
        params: Parameters to be optimized.
        optimizer_cls: Optimizer class from torch.optim.
        *args: Positional arguments to pass to the optimizer class.
        **kwargs: Keyword arguments to pass to the optimizer class.

    Returns:
        Initial OptimState.
    """
    optimizer = optimizer_cls(tree_leaves(params), *args, **kwargs)
    return OptimState(params, optimizer)
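`init` hands the optimizer a flat list of the parameter leaves, followed by any extra positional and keyword arguments, so `init(params, torch.optim.SGD, 0.01, momentum=0.9)` amounts to constructing `torch.optim.SGD(leaves, 0.01, momentum=0.9)`. The sketch below shows that construction for a nested dict of tensors. `flatten_leaves` is a hypothetical, simplified stand-in for `tree_leaves` (it handles only tensors, dicts, and lists/tuples), and the order in which it returns leaves may differ from the real function.

```python
import torch


def flatten_leaves(tree):
    """Hypothetical stand-in for tree_leaves: collect tensors from nested dicts/lists."""
    if isinstance(tree, torch.Tensor):
        return [tree]
    children = tree.values() if isinstance(tree, dict) else tree
    return [leaf for child in children for leaf in flatten_leaves(child)]


params = {
    "encoder": {
        "w": torch.randn(4, 3, requires_grad=True),
        "b": torch.zeros(4, requires_grad=True),
    },
    "head": [torch.randn(1, 4, requires_grad=True)],
}
leaves = flatten_leaves(params)

# Mirrors init: the positional argument after the parameter list is SGD's lr,
# and keyword arguments such as momentum are forwarded unchanged
optimizer = torch.optim.SGD(leaves, 0.01, momentum=0.9)
```

Because the optimizer only ever sees the flat list, any nesting of dicts and lists in `params` works, as long as every leaf is a tensor that requires gradients.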

posteriors.optim.update(state, batch, loss_fn, inplace=True)

Perform a single update step of a torch.optim optimizer.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `state` | `OptimState` | Current optimizer state. | required |
| `batch` | `TensorTree` | Input data to `loss_fn`. | required |
| `loss_fn` | `LogProbFn` | Function that takes the parameters and a batch and returns the loss, of the form `loss, aux = fn(params, batch)`. | required |
| `inplace` | `bool` | Whether to update the parameters in place. `inplace=False` is not supported for `posteriors.optim` and raises `NotImplementedError`. | `True` |

Returns:

| Type | Description |
|---|---|
| `tuple[OptimState, TensorTree]` | Updated `OptimState` and auxiliary information. |

Source code in posteriors/optim.py
def update(
    state: OptimState,
    batch: TensorTree,
    loss_fn: LogProbFn,
    inplace: bool = True,
) -> tuple[OptimState, TensorTree]:
    """Perform a single update step of a [torch.optim](https://pytorch.org/docs/stable/optim.html)
    optimizer.

    Args:
        state: Current optimizer state.
        batch: Input data to loss_fn.
        loss_fn: Function that takes the parameters and a batch and returns the loss,
            of the form `loss, aux = fn(params, batch)`.
        inplace: Whether to update the parameters in place.
            inplace=False not supported for posteriors.optim

    Returns:
        Updated OptimState and auxiliary information.
    """
    if not inplace:
        raise NotImplementedError("inplace=False not supported for posteriors.optim")
    state.optimizer.zero_grad()
    with CatchAuxError():
        loss, aux = loss_fn(state.params, batch)
    loss.backward()
    state.optimizer.step()
    tree_insert_(state.loss, loss.detach())
    return state, aux
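Each call to `update` performs one cycle of `zero_grad`, forward pass, `backward`, and `step`, then stores the detached loss. The sketch below re-creates that cycle with a hypothetical `sketch_update` function (it takes the params dict and optimizer directly instead of an `OptimState`) and checks one plain SGD step by hand: for the loss $\sum_i (w_i x_i)^2$ with $x = (1, 1)$ the gradient is $2w$, so with `lr=0.1` the parameters move from $(1, 2)$ to $(0.8, 1.6)$.

```python
import torch


def sketch_update(params, optimizer, batch, loss_fn, inplace=True):
    """Hypothetical re-creation of one update step for a dict of tensors."""
    if not inplace:
        raise NotImplementedError("inplace=False not supported")
    optimizer.zero_grad()  # clear gradients left over from the previous step
    loss, aux = loss_fn(params, batch)
    loss.backward()  # populate .grad on every parameter leaf
    optimizer.step()  # modify the parameter tensors in place
    return loss.detach(), aux


def loss_fn(params, batch):
    # sum_i (w_i * x_i)^2, with no auxiliary output
    return ((params["w"] * batch) ** 2).sum(), None


w = torch.tensor([1.0, 2.0], requires_grad=True)
opt = torch.optim.SGD([w], lr=0.1)
batch = torch.tensor([1.0, 1.0])

# loss = 1 + 4 = 5; grad = 2 * w = [2, 4]; new w = w - 0.1 * grad = [0.8, 1.6]
loss, _ = sketch_update({"w": w}, opt, batch, loss_fn)

try:
    sketch_update({"w": w}, opt, batch, loss_fn, inplace=False)
    raised = False
except NotImplementedError:
    raised = True
```

Calling `zero_grad` first matters: PyTorch accumulates gradients across `backward` calls, so skipping it would make each step use the sum of all past gradients.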