TorchOpt
posteriors.torchopt.build(loss_fn, optimizer)
Build a TorchOpt optimizer transformation.
Make sure to use the lower case functional optimizers, e.g. `torchopt.adam()`.
```python
transform = build(loss_fn, torchopt.adam(lr=0.1))
state = transform.init(params)
for batch in dataloader:
    state = transform.update(state, batch)
```
Parameters:

Name | Type | Description | Default |
---|---|---|---|
`loss_fn` | `LogProbFn` | Loss function. | required |
`optimizer` | `GradientTransformation` | TorchOpt functional optimizer. Make sure to use lower case, e.g. `torchopt.adam()`. | required |
Returns:

Type | Description |
---|---|
`Transform` | TorchOpt optimizer transform instance. |
Source code in posteriors/torchopt.py
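As a rough sketch (not taken from the docstring above), a loss function matching the `LogProbFn` signature takes `(params, batch)` and returns a `(loss, aux)` pair; the model, data shapes and empty `aux` tensor below are placeholder assumptions:

```python
import torch
import torchopt
import posteriors

model = torch.nn.Linear(10, 1)  # placeholder model
params = dict(model.named_parameters())

def loss_fn(params, batch):
    # Functional forward pass using the supplied parameter TensorTree
    x, y = batch
    pred = torch.func.functional_call(model, params, (x,))
    loss = torch.nn.functional.mse_loss(pred, y)
    return loss, torch.tensor([])  # second element is stored as aux

transform = posteriors.torchopt.build(loss_fn, torchopt.adam(lr=0.1))
state = transform.init(params)

batch = (torch.randn(32, 10), torch.randn(32, 1))  # dummy batch
state = transform.update(state, batch)
print(state.loss)
```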
posteriors.torchopt.TorchOptState
Bases: NamedTuple
State of a TorchOpt optimizer.
Contains the parameters, the TorchOpt optimizer state, the loss value, and auxiliary information.
Attributes:

Name | Type | Description |
---|---|---|
`params` | `TensorTree` | Parameters to be optimized. |
`opt_state` | `OptState` | TorchOpt optimizer state. |
`loss` | `tensor` | Loss value. |
`aux` | `Any` | Auxiliary information from the loss function call. |
Source code in posteriors/torchopt.py
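Illustratively, continuing the hypothetical example above, the fields can be read directly off the NamedTuple after an update:

```python
# Assuming `state` came from transform.init / transform.update above
print(state.loss)                  # loss tensor from the most recent update
print(state.aux)                   # auxiliary output of the loss function
current_params = state.params      # TensorTree of optimized parameters
torchopt_state = state.opt_state   # underlying TorchOpt optimizer state
```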
posteriors.torchopt.init(params, optimizer)
Initialise a TorchOpt optimizer.
Make sure to use the lower case functional optimizers, e.g. `torchopt.adam()`.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
`params` | `TensorTree` | Parameters to be optimized. | required |
`optimizer` | `GradientTransformation` | TorchOpt functional optimizer. Make sure to use lower case, e.g. `torchopt.adam()`. | required |
Returns:

Type | Description |
---|---|
`TorchOptState` | Initial `TorchOptState`. |
Source code in posteriors/torchopt.py
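For illustration, `init` can be called directly rather than through `build`; the optimizer here is assumed to be the same functional optimizer later passed to `update`, and `params` is any parameter TensorTree such as `dict(model.named_parameters())`:

```python
import torchopt
import posteriors

optimizer = torchopt.adam(lr=0.1)
state = posteriors.torchopt.init(params, optimizer)
```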
posteriors.torchopt.update(state, batch, loss_fn, optimizer, inplace=False)
Update the TorchOpt optimizer state.
Make sure to use the lower case functional optimizers, e.g. `torchopt.adam()`.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
`state` | `TorchOptState` | Current state. | required |
`batch` | `TensorTree` | Batch of data. | required |
`loss_fn` | `LogProbFn` | Loss function. | required |
`optimizer` | `GradientTransformation` | TorchOpt functional optimizer. Make sure to use lower case, e.g. `torchopt.adam()`. | required |
`inplace` | `bool` | Whether to update the state in place. | `False` |
Returns:

Type | Description |
---|---|
`TorchOptState` | Updated `TorchOptState`. |
Source code in posteriors/torchopt.py
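A minimal sketch of the standalone functional form, assuming `loss_fn`, `optimizer`, `params` and `dataloader` as in the earlier hypothetical examples; with `inplace=False` each call returns a fresh `TorchOptState` rather than mutating the input:

```python
state = posteriors.torchopt.init(params, optimizer)
for batch in dataloader:
    state = posteriors.torchopt.update(
        state, batch, loss_fn, optimizer, inplace=False
    )
print(state.loss)
```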