TorchOpt

posteriors.torchopt.build(loss_fn, optimizer)

Build a TorchOpt optimizer transformation.

Make sure to use the lower case functional optimizers e.g. torchopt.adam().

transform = build(loss_fn, torchopt.adam(lr=0.1))
state = transform.init(params)

for batch in dataloader:
    state = transform.update(state, batch)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `loss_fn` | `LogProbFn` | Loss function. | required |
| `optimizer` | `GradientTransformation` | TorchOpt functional optimizer. Make sure to use lower case e.g. `torchopt.adam()`. | required |

Returns:

| Type | Description |
| --- | --- |
| `Transform` | TorchOpt optimizer transform instance. |

Source code in posteriors/torchopt.py
def build(
    loss_fn: LogProbFn,
    optimizer: torchopt.base.GradientTransformation,
) -> Transform:
    """Build a [TorchOpt](https://github.com/metaopt/torchopt) optimizer transformation.

    Make sure to use the lower case functional optimizers e.g. `torchopt.adam()`.

    ```
    transform = build(loss_fn, torchopt.adam(lr=0.1))
    state = transform.init(params)

    for batch in dataloader:
        state = transform.update(state, batch)
    ```

    Args:
        loss_fn: Loss function.
        optimizer: TorchOpt functional optimizer.
            Make sure to use lower case e.g. torchopt.adam()

    Returns:
        Torchopt optimizer transform instance.
    """
    init_fn = partial(init, optimizer=optimizer)
    update_fn = partial(update, optimizer=optimizer, loss_fn=loss_fn)
    return Transform(init_fn, update_fn)
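
A minimal end-to-end sketch of using the built transform, assuming a loss function that returns a `(loss, aux)` pair (required because `update` calls `torch.func.grad_and_value(loss_fn, has_aux=True)`). The linear-regression model, parameter names and tensor shapes are illustrative, not part of the posteriors API:

```python
import torch
import torchopt
from posteriors.torchopt import build

# Hypothetical loss: mean squared error over a dict-of-tensors parameter tree.
# The (loss, aux) return signature is required; aux can be any pytree (an empty dict here).
def loss_fn(params, batch):
    x, y = batch
    pred = x @ params["weights"] + params["bias"]
    return torch.mean((pred - y) ** 2), {}

params = {"weights": torch.randn(3, 1), "bias": torch.zeros(1)}

transform = build(loss_fn, torchopt.adam(lr=0.1))
state = transform.init(params)

x = torch.randn(32, 3)
y = x @ torch.tensor([[1.0], [-2.0], [0.5]])
for _ in range(100):
    state = transform.update(state, (x, y))

print(state.loss)  # scalar tensor holding the most recent loss value
```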

posteriors.torchopt.TorchOptState

Bases: NamedTuple

State of a TorchOpt optimizer.

Contains the parameters, the optimizer state for the TorchOpt optimizer, loss value, and auxiliary information.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `params` | `TensorTree` | Parameters to be optimized. |
| `opt_state` | `OptState` | TorchOpt optimizer state. |
| `loss` | `tensor` | Loss value. |
| `aux` | `Any` | Auxiliary information from the loss function call. |

Source code in posteriors/torchopt.py
class TorchOptState(NamedTuple):
    """State of a [TorchOpt](https://github.com/metaopt/torchopt) optimizer.

    Contains the parameters, the optimizer state for the TorchOpt optimizer,
    loss value, and auxiliary information.

    Attributes:
        params: Parameters to be optimized.
        opt_state: TorchOpt optimizer state.
        loss: Loss value.
        aux: Auxiliary information from the loss function call.
    """

    params: TensorTree
    opt_state: torchopt.typing.OptState
    loss: torch.tensor = torch.tensor([])
    aux: Any = None
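
Because the state is a `NamedTuple`, its fields can be inspected or swapped with the usual tuple machinery. A small sketch, using the stand-alone `init` documented below and a made-up parameter tree:

```python
import torch
import torchopt
from posteriors.torchopt import init

params = {"w": torch.zeros(2)}              # hypothetical TensorTree
state = init(params, torchopt.adam(lr=0.1))

state.params     # the parameter pytree
state.opt_state  # TorchOpt's internal optimizer state (e.g. Adam moment buffers)
state.loss       # defaults to torch.tensor([]) until the first update
state.aux        # defaults to None until the first update

# Standard NamedTuple behaviour: _replace returns a copy with chosen fields swapped.
fresh = state._replace(loss=torch.tensor(0.0))
```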

posteriors.torchopt.init(params, optimizer)

Initialise a TorchOpt optimizer.

Make sure to use the lower case functional optimizers e.g. torchopt.adam().

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `params` | `TensorTree` | Parameters to be optimized. | required |
| `optimizer` | `GradientTransformation` | TorchOpt functional optimizer. Make sure to use lower case e.g. `torchopt.adam()`. | required |

Returns:

| Type | Description |
| --- | --- |
| `TorchOptState` | Initial TorchOptState. |

Source code in posteriors/torchopt.py
def init(
    params: TensorTree,
    optimizer: torchopt.base.GradientTransformation,
) -> TorchOptState:
    """Initialise a [TorchOpt](https://github.com/metaopt/torchopt) optimizer.

    Make sure to use the lower case functional optimizers e.g. `torchopt.adam()`.

    Args:
        params: Parameters to be optimized.
        optimizer: TorchOpt functional optimizer.
            Make sure to use lower case e.g. torchopt.adam()

    Returns:
        Initial TorchOptState.
    """
    opt_state = optimizer.init(params)
    return TorchOptState(params, opt_state)
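
A short sketch of initialising from an `nn.Module`'s parameters; the model is illustrative, any `TensorTree` of tensors works:

```python
import torch
import torchopt
from posteriors.torchopt import init

model = torch.nn.Linear(4, 2)
# Detach into a plain dict-of-tensors pytree for the functional API.
params = {k: v.detach().clone() for k, v in model.named_parameters()}

state = init(params, torchopt.adam(lr=1e-2))
# state.opt_state now holds torchopt.adam's per-parameter buffers (first and
# second moments) mirroring the structure of `params`; state.loss and state.aux
# keep their defaults until the first call to update.
```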

posteriors.torchopt.update(state, batch, loss_fn, optimizer, inplace=False)

Update the TorchOpt optimizer state.

Make sure to use the lower case functional optimizers e.g. torchopt.adam().

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `state` | `TorchOptState` | Current state. | required |
| `batch` | `TensorTree` | Batch of data. | required |
| `loss_fn` | `LogProbFn` | Loss function. | required |
| `optimizer` | `GradientTransformation` | TorchOpt functional optimizer. Make sure to use lower case like `torchopt.adam()`. | required |
| `inplace` | `bool` | Whether to update the state in place. | `False` |

Returns:

| Type | Description |
| --- | --- |
| `TorchOptState` | Updated TorchOptState. |

Source code in posteriors/torchopt.py
def update(
    state: TorchOptState,
    batch: TensorTree,
    loss_fn: LogProbFn,
    optimizer: torchopt.base.GradientTransformation,
    inplace: bool = False,
) -> TorchOptState:
    """Update the [TorchOpt](https://github.com/metaopt/torchopt) optimizer state.

    Make sure to use the lower case functional optimizers e.g. `torchopt.adam()`.

    Args:
        state: Current state.
        batch: Batch of data.
        loss_fn: Loss function.
        optimizer: TorchOpt functional optimizer.
            Make sure to use lower case like torchopt.adam()
        inplace: Whether to update the state in place.

    Returns:
        Updated TorchOptState.
    """
    params = state.params
    opt_state = state.opt_state
    with torch.no_grad(), CatchAuxError():
        grads, (loss, aux) = torch.func.grad_and_value(loss_fn, has_aux=True)(
            params, batch
        )
    updates, opt_state = optimizer.update(grads, opt_state, params=params)
    params = torchopt.apply_updates(params, updates, inplace=inplace)
    if inplace:
        tree_insert_(state.loss, loss.detach())
        return state._replace(aux=aux)

    return TorchOptState(params, opt_state, loss, aux)
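
A hedged sketch of calling `update` directly rather than through `build`, contrasting `inplace=False` with `inplace=True`; the quadratic loss and parameter names are made up for illustration:

```python
import torch
import torchopt
from posteriors.torchopt import init, update

def loss_fn(params, batch):
    x, y = batch
    return torch.mean((x @ params["w"] - y) ** 2), None  # (loss, aux) pair

params = {"w": torch.randn(3)}
optimizer = torchopt.adam(lr=0.1)
state = init(params, optimizer)

batch = (torch.randn(8, 3), torch.randn(8))

# inplace=False (the default): a brand-new TorchOptState is returned and the
# input state is left untouched.
new_state = update(state, batch, loss_fn, optimizer)

# inplace=True: parameters and the stored loss are written into the existing
# tensors, so the returned state wraps the same underlying tensors as the input.
state = update(state, batch, loss_fn, optimizer, inplace=True)
```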