Optim

posteriors.optim.build(loss_fn, optimizer, **kwargs)

Builds an optimizer transform from torch.optim

```
transform = build(loss_fn, torch.optim.Adam, lr=0.1)
state = transform.init(params)

for batch in dataloader:
    state = transform.update(state, batch)
```

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `loss_fn` | `LogProbFn` | Function that takes the parameters and returns the loss, of the form `loss, aux = fn(params, batch)`. | required |
| `optimizer` | `Type[Optimizer]` | Optimizer class from `torch.optim`. | required |
| `**kwargs` | `Any` | Keyword arguments to pass to the optimizer class. | `{}` |

Returns:

| Type | Description |
| --- | --- |
| `Transform` | `torch.optim` transform instance. |

Source code in posteriors/optim.py
def build(
    loss_fn: LogProbFn,
    optimizer: Type[torch.optim.Optimizer],
    **kwargs: Any,
) -> Transform:
    """Builds an optimizer transform from [torch.optim](https://pytorch.org/docs/stable/optim.html)

    ```
    transform = build(loss_fn, torch.optim.Adam, lr=0.1)
    state = transform.init(params)

    for batch in dataloader:
        state = transform.update(state, batch)
    ```

    Args:
        loss_fn: Function that takes the parameters and returns the loss,
            of the form `loss, aux = fn(params, batch)`.
        optimizer: Optimizer class from torch.optim.
        **kwargs: Keyword arguments to pass to the optimizer class.

    Returns:
        `torch.optim` transform instance.
    """
    init_fn = partial(init, optimizer_cls=optimizer, **kwargs)
    update_fn = partial(update, loss_fn=loss_fn)
    return Transform(init_fn, update_fn)
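
For a self-contained picture, here is a minimal sketch of the full loop. The quadratic loss, data and learning rate are illustrative assumptions, not part of the API:

```python
import torch
import posteriors

# Hypothetical toy problem: fit a 3-dim weight vector by least squares.
params = torch.randn(3, requires_grad=True)

def loss_fn(p, batch):
    x, y = batch
    loss = torch.mean((x @ p - y) ** 2)
    return loss, torch.tensor([])  # no auxiliary info to report

transform = posteriors.optim.build(loss_fn, torch.optim.Adam, lr=0.1)
state = transform.init(params)

batch = (torch.randn(100, 3), torch.randn(100))
for _ in range(10):
    state = transform.update(state, batch)

print(state.loss.item())
```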

posteriors.optim.OptimState

Bases: NamedTuple

State of an optimizer from torch.optim.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `params` | `TensorTree` | Parameters to be optimized. |
| `optimizer` | `Optimizer` | `torch.optim` optimizer instance. |
| `loss` | `Tensor` | Loss value. |
| `aux` | `Any` | Auxiliary information from the loss function call. |

Source code in posteriors/optim.py
class OptimState(NamedTuple):
    """State of an optimizer from [torch.optim](https://pytorch.org/docs/stable/optim.html).

    Attributes:
        params: Parameters to be optimized.
        optimizer: torch.optim optimizer instance.
        loss: Loss value.
        aux: Auxiliary information from the loss function call.
    """

    params: TensorTree
    optimizer: torch.optim.Optimizer
    loss: torch.Tensor = torch.tensor([])
    aux: Any = None
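
Because `OptimState` is a `NamedTuple`, its fields can be read directly between updates, e.g. for logging. A brief sketch, reusing `transform`, `state` and `batch` from the example above:

```python
state = transform.update(state, batch)
print(state.loss.item())  # loss from the most recent step
print(state.aux)          # auxiliary output of loss_fn, if any
```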

posteriors.optim.init(params, optimizer_cls, *args, **kwargs)

Initialise a torch.optim optimizer state.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `params` | `TensorTree` | Parameters to be optimized. | required |
| `optimizer_cls` | `Type[Optimizer]` | Optimizer class from `torch.optim`. | required |
| `*args` | `Any` | Positional arguments to pass to the optimizer class. | `()` |
| `**kwargs` | `Any` | Keyword arguments to pass to the optimizer class. | `{}` |

Returns:

| Type | Description |
| --- | --- |
| `OptimState` | Initial `OptimState`. |

Source code in posteriors/optim.py
def init(
    params: TensorTree,
    optimizer_cls: Type[torch.optim.Optimizer],
    *args: Any,
    **kwargs: Any,
) -> OptimState:
    """Initialise a [torch.optim](https://pytorch.org/docs/stable/optim.html) optimizer
    state.

    Args:
        params: Parameters to be optimized.
        optimizer_cls: Optimizer class from torch.optim.
        *args: Positional arguments to pass to the optimizer class.
        **kwargs: Keyword arguments to pass to the optimizer class.

    Returns:
        Initial OptimState.
    """
    opt_params = [params] if isinstance(params, torch.Tensor) else params

    optimizer = optimizer_cls(opt_params, *args, **kwargs)
    return OptimState(params, optimizer)
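
Note from the source that a bare `torch.Tensor` is wrapped in a list before being handed to the optimizer class, so both single tensors and parameter collections work. A minimal sketch (the `SGD` hyperparameters are illustrative):

```python
import torch
from posteriors.optim import init

params = torch.randn(5, requires_grad=True)
state = init(params, torch.optim.SGD, lr=0.01, momentum=0.9)

print(isinstance(state.optimizer, torch.optim.SGD))  # True
print(state.params is params)  # True: state keeps the original tensor
```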

posteriors.optim.update(state, batch, loss_fn, inplace=True)

Perform a single update step of a torch.optim optimizer.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `state` | `OptimState` | Current optimizer state. | required |
| `batch` | `TensorTree` | Input data to `loss_fn`. | required |
| `loss_fn` | `LogProbFn` | Function that takes the parameters and returns the loss, of the form `loss, aux = fn(params, batch)`. | required |
| `inplace` | `bool` | Whether to update the parameters in place. `inplace=False` is not supported for `posteriors.optim`. | `True` |

Returns:

| Type | Description |
| --- | --- |
| `OptimState` | Updated `OptimState`. |

Source code in posteriors/optim.py
def update(
    state: OptimState,
    batch: TensorTree,
    loss_fn: LogProbFn,
    inplace: bool = True,
) -> OptimState:
    """Perform a single update step of a [torch.optim](https://pytorch.org/docs/stable/optim.html)
    optimizer.

    Args:
        state: Current optimizer state.
        batch: Input data to loss_fn.
        loss_fn: Function that takes the parameters and returns the loss,
            of the form `loss, aux = fn(params, batch)`.
        inplace: Whether to update the parameters in place.
            `inplace=False` is not supported for `posteriors.optim`.

    Returns:
        Updated OptimState.
    """
    if not inplace:
        raise NotImplementedError("inplace=False not supported for posteriors.optim")
    state.optimizer.zero_grad()
    with CatchAuxError():
        loss, aux = loss_fn(state.params, batch)
    loss.backward()
    state.optimizer.step()
    tree_insert_(state.loss, loss.detach())
    return state._replace(aux=aux)
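
`init` and `update` can also be driven directly, without `build`. A minimal sketch of a hand-rolled loop; the quadratic loss and target values are illustrative assumptions:

```python
import torch
from posteriors.optim import init, update

params = torch.zeros(2, requires_grad=True)
target = torch.tensor([1.0, -1.0])

def loss_fn(p, batch):
    # Hypothetical quadratic loss pulling p towards the batch of targets.
    return ((p - batch) ** 2).sum(), torch.tensor([])

state = init(params, torch.optim.SGD, lr=0.1)
for _ in range(100):
    state = update(state, target, loss_fn)  # inplace=True is the default

print(state.params)  # approaches target as the loss is minimised
```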