Tree Utils
posteriors.tree_utils.tree_size(tree)
Returns the total number of elements in a PyTree: not the number of leaves, but the total number of elements across all tensors in the tree.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
tree | TensorTree | A PyTree of tensors. | required |
Returns:
Type | Description |
---|---|
int | Number of elements in the PyTree. |
Source code in posteriors/tree_utils.py
posteriors.tree_utils.tree_extract(tree, f)
Extracts values from a PyTree where f returns True. Leaves for which f returns False are replaced with empty tensors.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
tree | TensorTree | A PyTree. | required |
f | Callable[[tensor], bool] | A function that takes a PyTree element and returns True or False. | required |
Returns:
Type | Description |
---|---|
TensorTree | A PyTree with the same structure as tree, keeping the leaves where f returns True and holding empty tensors elsewhere. |
Source code in posteriors/tree_utils.py
posteriors.tree_utils.tree_insert(full_tree, sub_tree, f=lambda _: True)
Inserts sub_tree into full_tree at the leaves where f, evaluated on the full_tree tensor, returns True. Both PyTrees must have the same structure.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
full_tree | TensorTree | A PyTree to insert sub_tree into. | required |
sub_tree | TensorTree | A PyTree to insert into full_tree. | required |
f | Callable[[tensor], bool] | A function that takes a PyTree element and returns True or False. Defaults to lambda _: True, i.e. insert on all leaves. | lambda _: True |
Returns:
Type | Description |
---|---|
TensorTree | A PyTree with sub_tree inserted into full_tree. |
Source code in posteriors/tree_utils.py
posteriors.tree_utils.tree_insert_(full_tree, sub_tree, f=lambda _: True)
Inserts sub_tree into full_tree in-place at the leaves where f, evaluated on the full_tree tensor, returns True. Both PyTrees must have the same structure.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
full_tree | TensorTree | A PyTree to insert sub_tree into. | required |
sub_tree | TensorTree | A PyTree to insert into full_tree. | required |
f | Callable[[tensor], bool] | A function that takes a PyTree element and returns True or False. Defaults to lambda _: True, i.e. insert on all leaves. | lambda _: True |
Returns:
Type | Description |
---|---|
TensorTree | A pointer to full_tree with sub_tree inserted. |
Source code in posteriors/tree_utils.py
posteriors.tree_utils.extract_requires_grad(tree)
Extracts only parameters that require gradients.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
tree | TensorTree | A PyTree of tensors. | required |
Returns:
Type | Description |
---|---|
TensorTree | A PyTree of tensors that require gradients. |
Source code in posteriors/tree_utils.py
posteriors.tree_utils.insert_requires_grad(full_tree, sub_tree)
Inserts sub_tree into full_tree where the full_tree tensors have requires_grad=True. Both PyTrees must have the same structure.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
full_tree | TensorTree | A PyTree to insert sub_tree into. | required |
sub_tree | TensorTree | A PyTree to insert into full_tree. | required |
Returns:
Type | Description |
---|---|
TensorTree | A PyTree with sub_tree inserted into full_tree. |
Source code in posteriors/tree_utils.py
posteriors.tree_utils.insert_requires_grad_(full_tree, sub_tree)
Inserts sub_tree into full_tree in-place where the full_tree tensors have requires_grad=True. Both PyTrees must have the same structure.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
full_tree | TensorTree | A PyTree to insert sub_tree into. | required |
sub_tree | TensorTree | A PyTree to insert into full_tree. | required |
Returns:
Type | Description |
---|---|
TensorTree | A pointer to full_tree with sub_tree inserted. |
Source code in posteriors/tree_utils.py
posteriors.tree_utils.extract_requires_grad_and_func(tree, func, inplace=False)
Extracts only parameters that require gradients and converts a function that takes the full parameter tree (in its first argument) into one that takes the subtree.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
tree | TensorTree | A PyTree of tensors. | required |
func | Callable | A function that takes tree in its first argument. | required |
inplace | bool | Whether to modify the tree in-place when the new function is called. | False |
Returns:
Type | Description |
---|---|
Tuple[TensorTree, Callable] | A PyTree of tensors that require gradients, and a modified func that takes the subtree structure rather than the full tree in its first argument. |
Source code in posteriors/tree_utils.py
posteriors.tree_utils.inplacify(func)
Converts a function that takes a tensor as its first argument into one that takes the same arguments but modifies the first argument tensor in-place with the output of the function.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
func | Callable | A function that takes a tensor as its first argument and returns a modified version of said tensor. | required |
Returns:
Type | Description |
---|---|
Callable | A function that takes a tensor as its first argument and modifies it in-place. |
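A minimal sketch of the wrapping, assuming the output has the same shape as the input so it can be written back with Tensor.copy_ under torch.no_grad(). inplacify_sketch and shift are hypothetical names.

```python
import functools

import torch


def inplacify_sketch(func):
    """Hypothetical sketch: run func, then write its output into the first argument."""

    @functools.wraps(func)
    def func_(tensor, *args, **kwargs):
        out = func(tensor, *args, **kwargs)
        with torch.no_grad():
            tensor.copy_(out)  # assumes out has tensor's shape (or broadcasts to it)
        return tensor

    return func_


def shift(x, amount):
    return x + amount  # pure: returns a new tensor


shift_ = inplacify_sketch(shift)
t = torch.zeros(3)
out = shift_(t, 2.0)
print(out is t, t)  # True tensor([2., 2., 2.])
```

The original func stays pure and can still be used on its own; only the wrapped version mutates.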
Source code in posteriors/tree_utils.py
posteriors.tree_utils.tree_map_inplacify_(func, tree, *rests, is_leaf=None, none_is_leaf=False, namespace='')
Applies a pure function to each tensor in a PyTree in-place.
Like optree.tree_map_, but takes a pure function as input (and replaces its first argument with its output in-place) rather than a side-effect function.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
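func | Callable | A function that takes a tensor as its first argument and returns a modified version of said tensor. | required |
tree | pytree | A pytree to be mapped over, with each leaf providing the first positional argument to func. | required |
rests | tuple of pytree | A tuple of pytrees, each of which has the same structure as tree or has tree as a prefix. | () |
is_leaf | callable | An optionally specified function that will be called at each flattening step. It should return a boolean, with True stopping the traversal and the whole subtree being treated as a leaf, and False indicating the flattening should traverse the current object. | None |
none_is_leaf | bool | Whether to treat None as a leaf. If False, None is a non-leaf node with arity 0, so it is kept in the tree structure rather than in the leaves and remains in the result pytree. | False |
namespace | str | The registry namespace used for custom pytree node types. Defaults to '', i.e. the global namespace. | '' |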
Returns:
Type | Description |
---|---|
TensorTree | The original tree, with each leaf replaced in-place by the output of func. |
Source code in posteriors/tree_utils.py
posteriors.tree_utils.flexi_tree_map(func, tree, *rests, inplace=False, is_leaf=None, none_is_leaf=False, namespace='')
Applies a pure function to each tensor in a PyTree, with an inplace argument. Each leaf is mapped as
out_tensor = func(tensor, *rest_tensors)
where out_tensor has the same shape as tensor. Therefore
out_tree = flexi_tree_map(func, tree, *rests, inplace=True)
returns out_tree, a pointer to the original tree with its leaves (tensors) modified in-place.
If inplace=False, flexi_tree_map is equivalent to optree.tree_map and returns a new tree.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
func | Callable | A pure function that takes a tensor as its first argument and returns a modified version of said tensor. | required |
tree | pytree | A pytree to be mapped over, with each leaf providing the first positional argument to func. | required |
rests | tuple of pytree | A tuple of pytrees, each of which has the same structure as tree or has tree as a prefix. | () |
inplace | bool | Whether to modify the tree in-place or not. | False |
is_leaf | callable | An optionally specified function that will be called at each flattening step. It should return a boolean, with True stopping the traversal and the whole subtree being treated as a leaf, and False indicating the flattening should traverse the current object. | None |
none_is_leaf | bool | Whether to treat None as a leaf. If False, None is a non-leaf node with arity 0, so it is kept in the tree structure rather than in the leaves and remains in the result pytree. | False |
namespace | str | The registry namespace used for custom pytree node types. Defaults to '', i.e. the global namespace. | '' |
Returns:
Type | Description |
---|---|
TensorTree | Either the original tree modified in-place or a new tree, depending on the inplace argument. |
Source code in posteriors/tree_utils.py