Optim
posteriors.optim.build(loss_fn, optimizer, **kwargs)
Builds an optimizer transform from torch.optim. Example usage:
import torch
from posteriors.optim import build
transform = build(loss_fn, torch.optim.Adam, lr=0.1)
state = transform.init(params)
for batch in dataloader:
    state, aux = transform.update(state, batch)
Parameters:
Name | Type | Description | Default |
---|---|---|---|
loss_fn | LogProbFn | Function that takes the parameters and a batch of data and returns the loss along with auxiliary information. | required |
optimizer | Type[Optimizer] | Optimizer class from torch.optim. | required |
**kwargs | Any | Keyword arguments to pass to the optimizer class. | {} |
Returns:
Type | Description |
---|---|
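Transform | Optimizer transform with init and update methods, built from the given torch.optim optimizer class. |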
Source code in posteriors/optim.py
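A minimal sketch of a loss_fn that could be passed to build. It assumes the function takes (params, batch) and returns a (loss, aux) pair, which is consistent with update treating batch as the input data to loss_fn and returning auxiliary information. The parameter names w and b, the (x, y) batch layout and the mean squared error loss are hypothetical choices for illustration; only torch is required.

```python
import torch

torch.manual_seed(0)


def loss_fn(params, batch):
    # Hypothetical loss assuming the form: loss, aux = loss_fn(params, batch)
    x, y = batch
    preds = x @ params["w"] + params["b"]
    loss = torch.mean((preds - y) ** 2)
    # aux may be any tree of tensors; here it carries detached predictions.
    return loss, {"preds": preds.detach()}


# Hypothetical parameters and batch for a 3-feature linear regression.
params = {
    "w": torch.zeros(3, requires_grad=True),
    "b": torch.zeros((), requires_grad=True),
}
x = torch.randn(8, 3)
y = x @ torch.tensor([1.0, -2.0, 0.5]) + 0.3
loss, aux = loss_fn(params, (x, y))
```

With posteriors installed, a function of this shape is what would be supplied as the first argument in the example above, e.g. build(loss_fn, torch.optim.Adam, lr=0.1).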
posteriors.optim.OptimState
Bases: TensorClass['frozen']
State of an optimizer from torch.optim.
Attributes:
Name | Type | Description |
---|---|---|
params | TensorTree | Parameters to be optimized. |
optimizer | Optimizer | torch.optim optimizer instance. |
loss | Tensor | Loss value. |
Source code in posteriors/optim.py
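OptimState holds the parameters together with a torch.optim optimizer built on those same tensors. The sketch below uses a hypothetical frozen dataclass, ToyOptimState, with the three documented fields to show one consequence: torch.optim's step() mutates the parameter tensors in place, so state.params reflects the update without the field being reassigned. That this aliasing is why update does not support inplace=False is an inference, not something this reference states.

```python
from dataclasses import dataclass

import torch


@dataclass(frozen=True)
class ToyOptimState:
    # Hypothetical stand-in with the documented OptimState fields.
    params: dict
    optimizer: torch.optim.Optimizer
    loss: torch.Tensor


params = {"w": torch.tensor([1.0, 2.0], requires_grad=True)}
# The optimizer is constructed on the very tensor objects held in params.
optimizer = torch.optim.SGD(list(params.values()), lr=0.5)
state = ToyOptimState(params, optimizer, torch.tensor(float("nan")))

# One manual step on loss = sum(w ** 2), whose gradient is 2 * w.
state.optimizer.zero_grad()
loss = (state.params["w"] ** 2).sum()
loss.backward()
state.optimizer.step()  # w <- w - 0.5 * 2 * w, i.e. zero
# The dataclass is frozen, so write into the existing loss tensor instead.
state.loss.copy_(loss.detach())
```

After the step, params["w"] is [0.0, 0.0] and state.loss holds 5.0, the loss evaluated before the step.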
posteriors.optim.init(params, optimizer_cls, *args, **kwargs)
Initialise a torch.optim optimizer state.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
params | TensorTree | Parameters to be optimized. | required |
optimizer_cls | Type[Optimizer] | Optimizer class from torch.optim. | required |
*args | Any | Positional arguments to pass to the optimizer class. | () |
**kwargs | Any | Keyword arguments to pass to the optimizer class. | {} |
Returns:
Type | Description |
---|---|
OptimState | Initial OptimState. |
Source code in posteriors/optim.py
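Because *args and **kwargs are passed through to the optimizer class, positional options such as a learning rate and keyword options such as momentum reach the torch.optim constructor unchanged. The function init_sketch below is hypothetical: it assumes the TensorTree is flattened into a list of leaf tensors before the optimizer is built, using a small hand-written tree_leaves helper that only understands tensors, dicts, lists and tuples (the flattening posteriors itself uses may differ), and it returns a plain tuple rather than an OptimState.

```python
import torch


def tree_leaves(tree):
    # Hypothetical minimal flattener; dict keys are visited in sorted order.
    if isinstance(tree, torch.Tensor):
        return [tree]
    if isinstance(tree, dict):
        return [leaf for key in sorted(tree) for leaf in tree_leaves(tree[key])]
    return [leaf for sub in tree for leaf in tree_leaves(sub)]


def init_sketch(params, optimizer_cls, *args, **kwargs):
    # Forward *args and **kwargs unchanged to the torch.optim constructor.
    optimizer = optimizer_cls(tree_leaves(params), *args, **kwargs)
    return params, optimizer


params = {
    "layer": {"w": torch.ones(2, 2, requires_grad=True)},
    "b": torch.zeros(2, requires_grad=True),
}
# 0.1 is passed positionally as SGD's lr; momentum is passed by keyword.
_, opt = init_sketch(params, torch.optim.SGD, 0.1, momentum=0.9)
group = opt.param_groups[0]
```

The optimizer's param group holds the same tensor objects as the nested params dict, which is what lets later steps update params in place.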
posteriors.optim.update(state, batch, loss_fn, inplace=True)
Perform a single update step of a torch.optim optimizer.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
state | OptimState | Current optimizer state. | required |
batch | TensorTree | Input data to loss_fn. | required |
loss_fn | LogProbFn | Function that takes the parameters and a batch of data and returns the loss along with auxiliary information. | required |
inplace | bool | Whether to update the parameters in place. inplace=False is not supported for posteriors.optim. | True |
Returns:
Type | Description |
---|---|
tuple[OptimState, TensorTree] | Updated OptimState and auxiliary information. |
Source code in posteriors/optim.py
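A single torch.optim update step can be sketched as: zero the gradients, evaluate loss_fn, backpropagate, then call step(). The function update_sketch below is hypothetical. It takes the parameters and optimizer directly instead of an OptimState, assumes loss_fn returns a (loss, aux) pair, and raises NotImplementedError for inplace=False to mirror the documented restriction (the exception type is an assumption). The quadratic loss is a toy chosen so the SGD iterates are exact.

```python
import torch


def update_sketch(params, optimizer, batch, loss_fn, inplace=True):
    # Mirror the documented restriction: only in-place updates are supported.
    if not inplace:
        raise NotImplementedError("inplace=False not supported")
    optimizer.zero_grad()
    loss, aux = loss_fn(params, batch)  # assumed (loss, aux) return form
    loss.backward()  # populate .grad on the parameter tensors
    optimizer.step()  # torch.optim mutates the parameter tensors in place
    return loss.detach(), aux


def loss_fn(params, batch):
    # Hypothetical quadratic loss pulling w towards the batch mean.
    diff = params["w"] - batch.mean()
    return diff**2, {"diff": diff.detach()}


params = {"w": torch.tensor(0.0, requires_grad=True)}
optimizer = torch.optim.SGD([params["w"]], lr=0.25)
batch = torch.tensor([1.0, 3.0])  # mean 2.0
for _ in range(3):
    loss, aux = update_sketch(params, optimizer, batch, loss_fn)
```

Starting from w = 0 with lr = 0.25, the gradient 2 * (w - 2) moves w to 1.0, 1.5 and then 1.75. The loss and aux returned by the last call are evaluated before its step, at w = 1.5, giving 0.25 and a diff of -0.5.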