Tree Utils

posteriors.tree_utils.tree_size(tree)

Returns the total number of elements in a PyTree. This is the sum of the element counts of all tensors in the tree, not the number of leaves.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| tree | TensorTree | A PyTree of tensors. | required |

Returns:

| Type | Description |
| --- | --- |
| int | Number of elements in the PyTree. |

Source code in posteriors/tree_utils.py
def tree_size(tree: TensorTree) -> int:
    """Returns the total number of elements in a PyTree.
    Not the number of leaves, but the total number of elements for all tensors in the
    tree.

    Args:
        tree: A PyTree of tensors.

    Returns:
        Number of elements in the PyTree.
    """

    def ensure_tensor(x):
        return x if isinstance(x, torch.Tensor) else torch.tensor(x)

    return tree_reduce(torch.add, tree_map(lambda x: ensure_tensor(x).numel(), tree))

posteriors.tree_utils.tree_extract(tree, f)

Extracts values from a PyTree where f returns True. Leaves for which f returns False are replaced with empty tensors.

Parameters:

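| Name | Type | Description | Default |
| --- | --- | --- | --- |
| tree | TensorTree | A PyTree. | required |
| f | Callable[[Tensor], bool] | A function that takes a PyTree element and returns True or False. | required |

Returns:

| Type | Description |
| --- | --- |
| TensorTree | A PyTree with the same structure as tree, keeping the leaves for which f returns True and replacing the others with empty tensors. |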

Source code in posteriors/tree_utils.py
def tree_extract(
    tree: TensorTree,
    f: Callable[[torch.Tensor], bool],
) -> TensorTree:
    """Extracts values from a PyTree where f returns True.
    Leaves for which f returns False are replaced with empty tensors.

    Args:
        tree: A PyTree.
        f: A function that takes a PyTree element and returns True or False.

    Returns:
        A PyTree with the same structure as tree, keeping the leaves for which f
            returns True and replacing the others with empty tensors.
    """
    return tree_map(lambda x: x if f(x) else torch.tensor([], device=x.device), tree)

posteriors.tree_utils.tree_insert(full_tree, sub_tree, f=lambda _: True)

Inserts leaves of sub_tree into full_tree wherever f returns True for the corresponding full_tree tensor. Both PyTrees must have the same structure.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| full_tree | TensorTree | A PyTree to insert sub_tree into. | required |
| sub_tree | TensorTree | A PyTree to insert into full_tree. | required |
| f | Callable[[Tensor], bool] | A function that takes a PyTree element and returns True or False. Defaults to lambda _: True, i.e. insert on all leaves. | lambda _: True |

Returns:

| Type | Description |
| --- | --- |
| TensorTree | A PyTree with sub_tree inserted into full_tree. |

Source code in posteriors/tree_utils.py
def tree_insert(
    full_tree: TensorTree,
    sub_tree: TensorTree,
    f: Callable[[torch.Tensor], bool] = lambda _: True,
) -> TensorTree:
    """Inserts leaves of sub_tree into full_tree wherever f returns True for the
    corresponding full_tree tensor. Both PyTrees must have the same structure.

    Args:
        full_tree: A PyTree to insert sub_tree into.
        sub_tree: A PyTree to insert into full_tree.
        f: A function that takes a PyTree element and returns True or False.
            Defaults to lambda _: True, i.e. insert on all leaves.

    Returns:
        A PyTree with sub_tree inserted into full_tree.
    """
    return tree_map(
        lambda sub, full: sub if f(full) else full,
        sub_tree,
        full_tree,
    )

posteriors.tree_utils.tree_insert_(full_tree, sub_tree, f=lambda _: True)

Inserts leaves of sub_tree into full_tree in-place wherever f returns True for the corresponding full_tree tensor. Both PyTrees must have the same structure.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| full_tree | TensorTree | A PyTree to insert sub_tree into. | required |
| sub_tree | TensorTree | A PyTree to insert into full_tree. | required |
| f | Callable[[Tensor], bool] | A function that takes a PyTree element and returns True or False. Defaults to lambda _: True, i.e. insert on all leaves. | lambda _: True |

Returns:

| Type | Description |
| --- | --- |
| TensorTree | A pointer to full_tree with sub_tree inserted. |

Source code in posteriors/tree_utils.py
def tree_insert_(
    full_tree: TensorTree,
    sub_tree: TensorTree,
    f: Callable[[torch.Tensor], bool] = lambda _: True,
) -> TensorTree:
    """Inserts leaves of sub_tree into full_tree in-place wherever f returns True
    for the corresponding full_tree tensor. Both PyTrees must have the same
    structure.

    Args:
        full_tree: A PyTree to insert sub_tree into.
        sub_tree: A PyTree to insert into full_tree.
        f: A function that takes a PyTree element and returns True or False.
            Defaults to lambda _: True, i.e. insert on all leaves.

    Returns:
        A pointer to full_tree with sub_tree inserted.
    """

    def insert_(full, sub):
        if f(full):
            full.data = sub.data

    return tree_map_(insert_, full_tree, sub_tree)

posteriors.tree_utils.extract_requires_grad(tree)

Extracts only parameters that require gradients.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| tree | TensorTree | A PyTree of tensors. | required |

Returns:

| Type | Description |
| --- | --- |
| TensorTree | A PyTree of tensors that require gradients. |

Source code in posteriors/tree_utils.py
def extract_requires_grad(tree: TensorTree) -> TensorTree:
    """Extracts only parameters that require gradients.

    Args:
        tree: A PyTree of tensors.

    Returns:
        A PyTree of tensors that require gradients.
    """
    return tree_extract(tree, lambda x: x.requires_grad)

posteriors.tree_utils.insert_requires_grad(full_tree, sub_tree)

Inserts sub_tree into full_tree where full_tree tensors have requires_grad=True. Both PyTrees must have the same structure.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| full_tree | TensorTree | A PyTree to insert sub_tree into. | required |
| sub_tree | TensorTree | A PyTree to insert into full_tree. | required |

Returns:

| Type | Description |
| --- | --- |
| TensorTree | A PyTree with sub_tree inserted into full_tree. |

Source code in posteriors/tree_utils.py
def insert_requires_grad(full_tree: TensorTree, sub_tree: TensorTree) -> TensorTree:
    """Inserts sub_tree into full_tree where full_tree tensors have
    requires_grad=True. Both PyTrees must have the same structure.

    Args:
        full_tree: A PyTree to insert sub_tree into.
        sub_tree: A PyTree to insert into full_tree.

    Returns:
        A PyTree with sub_tree inserted into full_tree.
    """
    return tree_insert(full_tree, sub_tree, lambda x: x.requires_grad)

posteriors.tree_utils.insert_requires_grad_(full_tree, sub_tree)

Inserts sub_tree into full_tree in-place where full_tree tensors have requires_grad=True. Both PyTrees must have the same structure.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| full_tree | TensorTree | A PyTree to insert sub_tree into. | required |
| sub_tree | TensorTree | A PyTree to insert into full_tree. | required |

Returns:

| Type | Description |
| --- | --- |
| TensorTree | A pointer to full_tree with sub_tree inserted. |

Source code in posteriors/tree_utils.py
def insert_requires_grad_(full_tree: TensorTree, sub_tree: TensorTree) -> TensorTree:
    """Inserts sub_tree into full_tree in-place where full_tree tensors have
    requires_grad=True. Both PyTrees must have the same structure.

    Args:
        full_tree: A PyTree to insert sub_tree into.
        sub_tree: A PyTree to insert into full_tree.

    Returns:
        A pointer to full_tree with sub_tree inserted.
    """
    return tree_insert_(full_tree, sub_tree, lambda x: x.requires_grad)

posteriors.tree_utils.extract_requires_grad_and_func(tree, func, inplace=False)

Extracts only parameters that require gradients and converts a function that takes the full parameter tree (in its first argument) into one that takes the subtree.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| tree | TensorTree | A PyTree of tensors. | required |
| func | Callable | A function that takes tree in its first argument. | required |
| inplace | bool | Whether to modify the tree in-place when the new function is called. | False |

Returns:

| Type | Description |
| --- | --- |
| Tuple[TensorTree, Callable] | A PyTree of tensors that require gradients and a modified func that takes the subtree structure rather than full tree in its first argument. |

Source code in posteriors/tree_utils.py
def extract_requires_grad_and_func(
    tree: TensorTree, func: Callable, inplace: bool = False
) -> Tuple[TensorTree, Callable]:
    """Extracts only parameters that require gradients and converts a function
    that takes the full parameter tree (in its first argument)
    into one that takes the subtree.

    Args:
        tree: A PyTree of tensors.
        func: A function that takes tree in its first argument.
        inplace: Whether to modify the tree in-place when the new function
            is called.

    Returns:
        A PyTree of tensors that require gradients and a modified func that takes the
            subtree structure rather than full tree in its first argument.
    """
    subtree = extract_requires_grad(tree)

    insert = insert_requires_grad_ if inplace else insert_requires_grad

    def subfunc(subtree, *args, **kwargs):
        return func(insert(tree, subtree), *args, **kwargs)

    return subtree, subfunc

posteriors.tree_utils.inplacify(func)

Converts a function that takes a tensor as its first argument into one that takes the same arguments but modifies the first argument tensor in-place with the output of the function.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| func | Callable | A function that takes a tensor as its first argument and returns a modified version of said tensor. | required |

Returns:

| Type | Description |
| --- | --- |
| Callable | A function that takes a tensor as its first argument and modifies it in-place. |

Source code in posteriors/tree_utils.py
def inplacify(func: Callable) -> Callable:
    """Converts a function that takes a tensor as its first argument
    into one that takes the same arguments but modifies the first argument
    tensor in-place with the output of the function.

    Args:
        func: A function that takes a tensor as its first argument and returns
            a modified version of said tensor.

    Returns:
        A function that takes a tensor as its first argument and modifies it
            in-place.
    """

    def func_(tens, *args, **kwargs):
        tens.data = func(tens, *args, **kwargs)
        return tens

    return func_
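
A small usage sketch, assuming `torch` is installed: the wrapper is re-declared locally, and `double_` is a hypothetical name for the wrapped function.

```python
import torch


def inplacify(func):
    # Local copy of the listing above, kept so the example is self-contained.
    def func_(tens, *args, **kwargs):
        tens.data = func(tens, *args, **kwargs)
        return tens

    return func_


# Wrap a pure doubling function so that it overwrites its input instead.
double_ = inplacify(lambda x: 2 * x)

t = torch.tensor([1.0, 2.0])
out = double_(t)

same_object = out is t  # the input tensor itself is returned
values = t.tolist()     # its data now holds the doubled values
```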

posteriors.tree_utils.tree_map_inplacify_(func, tree, *rests, is_leaf=None, none_is_leaf=False, namespace='')

Applies a pure function to each tensor in a PyTree in-place.

Like optree.tree_map_ but takes a pure function as input (and replaces its first argument with its output in-place) rather than a side-effect function.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| func | Callable | A function that takes a tensor as its first argument and returns a modified version of said tensor. | required |
| tree | pytree | A pytree to be mapped over, with each leaf providing the first positional argument to function func. | required |
| rests | tuple of pytree | A tuple of pytrees, each of which has the same structure as tree or has tree as a prefix. | () |
| is_leaf | callable | An optionally specified function that will be called at each flattening step. It should return a boolean, with True stopping the traversal and the whole subtree being treated as a leaf, and False indicating the flattening should traverse the current object. | None |
| none_is_leaf | bool | Whether to treat None as a leaf. If False, None is a non-leaf node with arity 0. Thus None is contained in the treespec rather than in the leaves list and None will remain in the result pytree. (default: False) | False |
| namespace | str | The registry namespace used for custom pytree node types. (default: '', i.e., the global namespace) | '' |

Returns:

| Type | Description |
| --- | --- |
| TensorTree | The original tree, with each leaf replaced in-place by the return value of func(x, *xs), where x is the value at the corresponding leaf in tree and xs is the tuple of values at the corresponding nodes in rests. |

Source code in posteriors/tree_utils.py
def tree_map_inplacify_(
    func: Callable,
    tree: TensorTree,
    *rests: TensorTree,
    is_leaf: Callable[[TensorTree], bool] | None = None,
    none_is_leaf: bool = False,
    namespace: str = "",
) -> TensorTree:
    """Applies a pure function to each tensor in a PyTree in-place.

    Like [optree.tree_map_](https://optree.readthedocs.io/en/latest/ops.html#optree.tree_map_)
    but takes a pure function as input (and replaces its first argument with its
    output in-place) rather than a side-effect function.

    Args:
        func: A function that takes a tensor as its first argument and returns
            a modified version of said tensor.
        tree (pytree): A pytree to be mapped over, with each leaf providing the first
            positional argument to function ``func``.
        rests (tuple of pytree): A tuple of pytrees, each of which has the same
            structure as ``tree`` or has ``tree`` as a prefix.
        is_leaf (callable, optional): An optionally specified function that will be
            called at each flattening step. It should return a boolean, with
            `True` stopping the traversal and the whole subtree being treated as a
            leaf, and `False` indicating the flattening should traverse the
            current object.
        none_is_leaf (bool, optional): Whether to treat `None` as a leaf. If
            `False`, `None` is a non-leaf node with arity 0. Thus `None` is contained in
            the treespec rather than in the leaves list and `None` will remain in the
            result pytree. (default: `False`)
        namespace (str, optional): The registry namespace used for custom pytree node
            types. (default: `''`, i.e., the global namespace)

    Returns:
        The original ``tree``, with each leaf replaced in-place by the return value
            of ``func(x, *xs)``, where ``x`` is the value at the corresponding leaf
            in ``tree`` and ``xs`` is the tuple of values at the corresponding nodes
            in ``rests``.
    """
    return tree_map_(
        inplacify(func),
        tree,
        *rests,
        is_leaf=is_leaf,
        none_is_leaf=none_is_leaf,
        namespace=namespace,
    )

posteriors.tree_utils.flexi_tree_map(func, tree, *rests, inplace=False, is_leaf=None, none_is_leaf=False, namespace='')

Applies a pure function to each tensor in a PyTree, with an inplace argument.

func is applied leaf-wise as

out_tensor = func(tensor, *rest_tensors)

where out_tensor has the same shape as tensor. Therefore

out_tree = flexi_tree_map(func, tree, *rests, inplace=True)

returns out_tree, a pointer to the original tree with leaves (tensors) modified in place. If inplace=False, flexi_tree_map is equivalent to optree.tree_map and returns a new tree.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| func | Callable | A pure function that takes a tensor as its first argument and returns a modified version of said tensor. | required |
| tree | pytree | A pytree to be mapped over, with each leaf providing the first positional argument to function func. | required |
| rests | tuple of pytree | A tuple of pytrees, each of which has the same structure as tree or has tree as a prefix. | () |
| inplace | bool | Whether to modify the tree in-place or not. | False |
| is_leaf | callable | An optionally specified function that will be called at each flattening step. It should return a boolean, with True stopping the traversal and the whole subtree being treated as a leaf, and False indicating the flattening should traverse the current object. | None |
| none_is_leaf | bool | Whether to treat None as a leaf. If False, None is a non-leaf node with arity 0. Thus None is contained in the treespec rather than in the leaves list and None will remain in the result pytree. (default: False) | False |
| namespace | str | The registry namespace used for custom pytree node types. (default: '', i.e., the global namespace) | '' |

Returns:

| Type | Description |
| --- | --- |
| TensorTree | Either the original tree modified in-place or a new tree depending on the inplace argument. |

Source code in posteriors/tree_utils.py
def flexi_tree_map(
    func: Callable,
    tree: TensorTree,
    *rests: TensorTree,
    inplace: bool = False,
    is_leaf: Callable[[TensorTree], bool] | None = None,
    none_is_leaf: bool = False,
    namespace: str = "",
) -> TensorTree:
    """Applies a pure function to each tensor in a PyTree, with an inplace argument.

    `func` is applied leaf-wise as

    ```
    out_tensor = func(tensor, *rest_tensors)
    ```

    where `out_tensor` has the same shape as `tensor`.
    Therefore

    ```
    out_tree = flexi_tree_map(func, tree, *rests, inplace=True)
    ```

    returns `out_tree`, a pointer to the original `tree` with leaves (tensors)
    modified in place.
    If `inplace=False`, `flexi_tree_map` is equivalent to [`optree.tree_map`](https://optree.readthedocs.io/en/latest/ops.html#optree.tree_map)
    and returns a new tree.

    Args:
        func: A pure function that takes a tensor as its first argument and returns
            a modified version of said tensor.
        tree (pytree): A pytree to be mapped over, with each leaf providing the first
            positional argument to function ``func``.
        rests (tuple of pytree): A tuple of pytrees, each of which has the same
            structure as ``tree`` or has ``tree`` as a prefix.
        inplace (bool, optional): Whether to modify the tree in-place or not.
        is_leaf (callable, optional): An optionally specified function that will be
            called at each flattening step. It should return a boolean, with `True`
            stopping the traversal and the whole subtree being treated as a leaf, and
            `False` indicating the flattening should traverse the current object.
        none_is_leaf (bool, optional): Whether to treat `None` as a leaf. If `False`,
            `None` is a non-leaf node with arity 0. Thus `None` is contained in the
            treespec rather than in the leaves list and `None` will remain in the
            result pytree. (default: `False`)
        namespace (str, optional): The registry namespace used for custom pytree node
            types. (default: `''`, i.e., the global namespace)

    Returns:
        Either the original tree modified in-place or a new tree depending on the
            `inplace` argument.
    """
    tm = tree_map_inplacify_ if inplace else tree_map
    return tm(
        func,
        tree,
        *rests,
        is_leaf=is_leaf,
        none_is_leaf=none_is_leaf,
        namespace=namespace,
    )