Optim
posteriors.optim.build(loss_fn, optimizer, **kwargs)
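Builds an optimizer transform from torch.optim.
```python
import torch
from posteriors.optim import build

transform = build(loss_fn, torch.optim.Adam, lr=0.1)  # loss_fn, params, dataloader defined elsewhere
state = transform.init(params)
for batch in dataloader:
    state = transform.update(state, batch)
```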
Parameters:
Name | Type | Description | Default |
---|---|---|---|
loss_fn | LogProbFn | Function of the form `loss, aux = loss_fn(params, batch)` that takes the parameters and a batch and returns the loss together with auxiliary information. | required |
optimizer | Type[Optimizer] | Optimizer class from torch.optim. | required |
`**kwargs` | Any | Keyword arguments to pass to the optimizer class. | {} |
Returns:
Type | Description |
---|---|
Transform | `torch.optim` transform instance. |
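The transform returned by build bundles an init function and an update function. As a rough, torch-only sketch (not the library's source), the same pattern can be assembled with `functools.partial`: the optimizer class and its keyword arguments are bound into init, and `loss_fn` is bound into update, so the training loop only passes `params` and `(state, batch)`. `SketchTransform`, `sketch_init`, `sketch_update` and `sketch_build` are hypothetical names, the state is a plain dict rather than `OptimState`, and `loss_fn` is assumed to have the form `loss, aux = loss_fn(params, batch)`.

```python
from functools import partial
from typing import Any, Callable, NamedTuple

import torch


class SketchTransform(NamedTuple):
    # Hypothetical stand-in for the Transform returned by build: two functions.
    init: Callable[..., Any]
    update: Callable[..., Any]


def sketch_init(params, optimizer_cls, **kwargs):
    # Create the torch.optim optimizer over a flat dict of parameter tensors.
    optimizer = optimizer_cls(list(params.values()), **kwargs)
    return {"params": params, "optimizer": optimizer, "loss": None, "aux": None}


def sketch_update(state, batch, loss_fn):
    # One step: optimizer.step() modifies the parameter tensors in place.
    state["optimizer"].zero_grad()
    loss, aux = loss_fn(state["params"], batch)
    loss.backward()
    state["optimizer"].step()
    return {**state, "loss": loss.detach(), "aux": aux}


def sketch_build(loss_fn, optimizer, **kwargs):
    # Bind the optimizer class and its kwargs into init, and loss_fn into update,
    # so callers only pass params to init and (state, batch) to update.
    return SketchTransform(
        init=partial(sketch_init, optimizer_cls=optimizer, **kwargs),
        update=partial(sketch_update, loss_fn=loss_fn),
    )


def loss_fn(params, batch):
    # Mean squared error of a 1-D linear model; the predictions are returned as aux.
    x, y = batch
    preds = x * params["w"] + params["b"]
    return torch.mean((preds - y) ** 2), preds.detach()


x = torch.linspace(-1.0, 1.0, 32)
y = 3.0 * x - 0.5
dataloader = [(x, y)] * 200  # the same full batch, repeated 200 times

params = {"w": torch.zeros(1, requires_grad=True), "b": torch.zeros(1, requires_grad=True)}
transform = sketch_build(loss_fn, torch.optim.Adam, lr=0.1)
state = transform.init(params)
for batch in dataloader:
    state = transform.update(state, batch)
```

Binding with `partial` keeps init and update as plain functions, which is why the loop in the build description never mentions the optimizer class or `loss_fn` again after the transform is built.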
Source code in posteriors/optim.py
posteriors.optim.OptimState
Bases: NamedTuple
State of an optimizer from torch.optim.
Attributes:
Name | Type | Description |
---|---|---|
params | TensorTree | Parameters to be optimized. |
optimizer | Optimizer | torch.optim optimizer instance. |
loss | tensor | Loss value. |
aux | Any | Auxiliary information from the loss function call. |
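Because the state is a `NamedTuple`, its fields are fixed and accessed by name, and a changed value means building a new tuple. The stored optimizer also holds references to the same tensors as `params`, so an optimizer step changes `params` in place. The sketch below illustrates both points with a hypothetical `SketchOptimState` that mirrors the four documented attributes; it is not the library class, and the NaN placeholder loss is an arbitrary choice for this example.

```python
from typing import Any, NamedTuple

import torch


class SketchOptimState(NamedTuple):
    # Hypothetical mirror of the four documented OptimState attributes.
    params: Any  # a TensorTree; here a dict of tensors
    optimizer: torch.optim.Optimizer
    loss: torch.Tensor
    aux: Any


params = {"w": torch.tensor([1.0, -2.0], requires_grad=True)}
optimizer = torch.optim.SGD(list(params.values()), lr=0.5)
state = SketchOptimState(params, optimizer, torch.tensor(float("nan")), None)

# One manual step on loss = sum(w ** 2), whose gradient is 2 * w.
optimizer.zero_grad()
loss = (state.params["w"] ** 2).sum()
loss.backward()
optimizer.step()

# NamedTuples are immutable, so a new state is built with the recorded loss.
state = state._replace(loss=loss.detach())

# The optimizer stepped the very tensor held in state.params: w - 0.5 * 2 * w = 0.
print(state.params["w"], state.loss)
```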
Source code in posteriors/optim.py
posteriors.optim.init(params, optimizer_cls, *args, **kwargs)
Initialise a torch.optim optimizer state.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
params | TensorTree | Parameters to be optimized. | required |
optimizer_cls | Type[Optimizer] | Optimizer class from torch.optim. | required |
`*args` | Any | Positional arguments to pass to the optimizer class. | () |
`**kwargs` | Any | Keyword arguments to pass to the optimizer class. | {} |
Returns:
Type | Description |
---|---|
OptimState | Initial OptimState. |
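A torch.optim optimizer only needs the tensors (leaves) of a parameter tree, since its constructor accepts an iterable of tensors, and `*args` and `**kwargs` can be forwarded unchanged to that constructor. The torch-only sketch below shows this under stated assumptions: `tree_leaves` is a hypothetical recursive helper for nested dicts, lists and tuples (the library may flatten trees differently), and `sketch_init` returns a plain `(params, optimizer)` tuple instead of an `OptimState`.

```python
from typing import Any, List, Type

import torch


def tree_leaves(tree: Any) -> List[torch.Tensor]:
    # Hypothetical helper: collect tensors from nested dicts, lists and tuples.
    if isinstance(tree, torch.Tensor):
        return [tree]
    if isinstance(tree, dict):
        children = tree.values()
    elif isinstance(tree, (list, tuple)):
        children = tree
    else:
        return []
    return [leaf for child in children for leaf in tree_leaves(child)]


def sketch_init(params: Any, optimizer_cls: Type[torch.optim.Optimizer], *args: Any, **kwargs: Any):
    # Positional and keyword arguments go straight to the optimizer class.
    optimizer = optimizer_cls(tree_leaves(params), *args, **kwargs)
    return params, optimizer


params = {
    "layer": {"weight": torch.randn(3, 2, requires_grad=True), "bias": torch.zeros(3, requires_grad=True)},
    "scale": torch.ones(1, requires_grad=True),
}
# lr is passed positionally (via *args) and momentum by keyword (via **kwargs).
params, optimizer = sketch_init(params, torch.optim.SGD, 0.01, momentum=0.9)
```

The optimizer ends up holding the original tensor objects, not copies, which is what lets later optimizer steps update `params` directly.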
Source code in posteriors/optim.py
posteriors.optim.update(state, batch, loss_fn, inplace=True)
Perform a single update step of a torch.optim optimizer.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
state | OptimState | Current optimizer state. | required |
batch | TensorTree | Input data to loss_fn. | required |
loss_fn | LogProbFn | Function of the form `loss, aux = loss_fn(params, batch)` that takes the parameters and a batch and returns the loss together with auxiliary information. | required |
inplace | bool | Whether to update the parameters in place. `inplace=False` is not supported for posteriors.optim. | True |
Returns:
Type | Description |
---|---|
OptimState | Updated OptimState. |
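A single torch.optim update can be sketched in plain PyTorch: clear gradients, evaluate `loss_fn` on the current parameters and batch, backpropagate, step the optimizer (which modifies the parameter tensors in place) and return a new state carrying the detached loss and `aux`. This is an illustration under assumptions, not the library's source: `SketchState` and `sketch_update` are hypothetical, `loss_fn` is assumed to have the form `loss, aux = loss_fn(params, batch)`, and the `NotImplementedError` mirrors the documented lack of support for `inplace=False`.

```python
from typing import Any, Callable, NamedTuple

import torch


class SketchState(NamedTuple):
    # Hypothetical stand-in for OptimState.
    params: Any
    optimizer: torch.optim.Optimizer
    loss: torch.Tensor
    aux: Any


def sketch_update(state: SketchState, batch: Any, loss_fn: Callable, inplace: bool = True) -> SketchState:
    if not inplace:
        # Mirrors the documented restriction: torch.optim steps mutate tensors in place.
        raise NotImplementedError("inplace=False not supported")
    state.optimizer.zero_grad()
    loss, aux = loss_fn(state.params, batch)
    loss.backward()
    state.optimizer.step()
    # Detach so the stored loss does not keep the autograd graph alive.
    return SketchState(state.params, state.optimizer, loss.detach(), aux)


def loss_fn(params, batch):
    # Squared distance to the batch mean; aux reports the estimate before the step.
    return ((params["mu"] - batch.mean()) ** 2).sum(), params["mu"].detach().clone()


params = {"mu": torch.zeros(1, requires_grad=True)}
state = SketchState(params, torch.optim.SGD([params["mu"]], lr=0.25), torch.tensor(float("nan")), None)
batch = torch.tensor([1.0, 2.0, 3.0])  # mean 2.0
losses = []
for _ in range(20):
    state = sketch_update(state, batch, loss_fn)
    losses.append(state.loss.item())
```

With `lr=0.25` each step halves the distance between `mu` and the batch mean, so the recorded losses fall by a factor of four per step.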
Source code in posteriors/optim.py