
Types

posteriors.types.InitFn

Bases: Protocol

Source code in posteriors/types.py
class InitFn(Protocol):
    def __call__(
        self,
        params: TensorTree,
    ) -> TensorClass:
        """Initiate a posteriors state with unified API:

        ```
        state = init(params)
        ```

        where params is a PyTree of parameters. The produced `state` is a
        `tensordict.TensorClass` containing the required information for the
        posteriors iterative algorithm defined by the `init` and `update` functions.

        Note that this represents the `init` function as stored in a `Transform`
        returned by an algorithm's `build` function; the internal `init` function in
        the algorithm module can, and likely will, take additional arguments.

        Args:
            params: PyTree containing initial value of parameters.

        Returns:
            The initial state, a `tensordict.tensorclass` with `params` and `aux`
            attributes but possibly other attributes too.
        """
        ...  # pragma: no cover

__call__(params)

Initialize a posteriors state using the unified API:

state = init(params)

where params is a PyTree of parameters. The produced state is a tensordict.TensorClass containing the required information for the posteriors iterative algorithm defined by the init and update functions.

Note that this describes the init function as stored in the Transform returned by an algorithm's build function; the internal init function in the algorithm module can, and likely will, take additional arguments.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| params | TensorTree | PyTree containing the initial values of the parameters. | required |

Returns:

| Type | Description |
|------|-------------|
| TensorClass | The initial state: a tensordict.tensorclass with params and aux attributes, possibly with other attributes too. |
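As a hedged illustration of the note about additional arguments, the sketch below shows how an algorithm's internal `init` that takes an extra argument can be reduced to the single-argument `InitFn` shape with `functools.partial`. It is dependency-free by assumption: a plain `dataclass` called `ToyState` stands in for a tensordict TensorClass, and a dict of floats stands in for a `TensorTree`. `ToyState`, `toy_init`, `momenta_init` and `InitFnLike` are hypothetical names, not part of posteriors.

```python
from dataclasses import dataclass
from functools import partial
from typing import Any, Dict, Protocol


@dataclass
class ToyState:
    # Hypothetical stand-in for a tensordict TensorClass state.
    params: Dict[str, float]
    momenta: Dict[str, float]
    aux: Any = None


class InitFnLike(Protocol):
    # Mirrors the shape of posteriors.types.InitFn: a single `params` argument.
    def __call__(self, params: Dict[str, float]) -> ToyState: ...


def toy_init(params: Dict[str, float], momenta_init: float = 0.0) -> ToyState:
    # Hypothetical internal init with an extra algorithm-specific argument.
    # The params dict is copied so the state does not alias the caller's dict.
    return ToyState(
        params=dict(params),
        momenta={name: momenta_init for name in params},
    )


# Binding the extra argument leaves a callable with the unified init(params) API.
init: InitFnLike = partial(toy_init, momenta_init=0.0)

params = {"w": 1.0, "b": -0.5}
state = init(params)
```

Because `Protocol` uses structural typing, `toy_init` never subclasses `InitFnLike`; the annotation is only meaningful to static type checkers and is not enforced at runtime.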

posteriors.types.UpdateFn

Bases: Protocol

Source code in posteriors/types.py
class UpdateFn(Protocol):
    def __call__(
        self,
        state: TensorClass,
        batch: Any,
        inplace: bool = False,
    ) -> tuple[TensorClass, TensorTree]:
        """Transform a posteriors state with unified API:

        ```
        state, aux = update(state, batch, inplace=False)
        ```

        where state is a `tensordict.TensorClass` containing the required information
        for the posteriors iterative algorithm defined by the `init` and `update`
        functions. `aux` is an arbitrary info object returned by the
        `log_posterior` or `log_likelihood` function.

        Note that this represents the `update` function as stored in a `Transform`
        returned by an algorithm's `build` function; the internal `update` function in
        the algorithm module can, and likely will, take additional arguments.

        Args:
            state: The `tensordict.tensorclass` state of the iterative algorithm.
            batch: The data batch.
            inplace: Whether to modify state using inplace operations. Defaults to False.

        Returns:
            Tuple of `state` and `aux`.
                `state` is a `tensordict.tensorclass` with `params` attributes
                but possibly other attributes too. Must be of the same type as
                the input state.
                `aux` is an arbitrary info object returned by the
                `log_posterior` or `log_likelihood` function.
        """
        ...  # pragma: no cover

__call__(state, batch, inplace=False)

Transform a posteriors state using the unified API:

state, aux = update(state, batch, inplace=False)

where state is a tensordict.TensorClass containing the required information for the posteriors iterative algorithm defined by the init and update functions. aux is an arbitrary info object returned by the log_posterior or log_likelihood function.

Note that this describes the update function as stored in the Transform returned by an algorithm's build function; the internal update function in the algorithm module can, and likely will, take additional arguments.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| state | TensorClass | The tensordict.tensorclass state of the iterative algorithm. | required |
| batch | Any | The data batch. | required |
| inplace | bool | Whether to modify state using in-place operations. Defaults to False. | False |

Returns:

| Type | Description |
|------|-------------|
| tuple[TensorClass, TensorTree] | Tuple of state and aux. state is a tensordict.tensorclass with a params attribute, possibly with other attributes too, and must be of the same type as the input state. aux is an arbitrary info object returned by the log_posterior or log_likelihood function. |
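A minimal sketch of a function with the `UpdateFn` shape, assuming a toy gradient-ascent step on an unnormalised Gaussian log posterior. It illustrates the two contracts described above: the return value is a `(state, aux)` pair, and `inplace=False` leaves the input state untouched while `inplace=True` mutates and returns it. `ToyState` is a hypothetical `dataclass` standing in for a TensorClass, and `toy_log_posterior`, `toy_update` and `lr` are illustrative names, not posteriors APIs.

```python
from dataclasses import dataclass
from functools import partial
from typing import Any, Dict, List, Tuple


@dataclass
class ToyState:
    # Hypothetical stand-in for a tensordict TensorClass state.
    params: Dict[str, float]
    aux: Any = None


def toy_log_posterior(
    params: Dict[str, float], batch: List[float]
) -> Tuple[float, float]:
    # Unnormalised unit-variance Gaussian log density for the mean `mu`,
    # returned together with its gradient with respect to `mu`.
    mu = params["mu"]
    value = -0.5 * sum((x - mu) ** 2 for x in batch)
    grad = sum(x - mu for x in batch)
    return value, grad


def toy_update(
    state: ToyState, batch: List[float], lr: float, inplace: bool = False
) -> Tuple[ToyState, float]:
    # One gradient-ascent step; `aux` is the log posterior at the old params.
    value, grad = toy_log_posterior(state.params, batch)
    new_mu = state.params["mu"] + lr * grad
    if inplace:
        # Mutate the existing state object and hand it back.
        state.params["mu"] = new_mu
        state.aux = value
        return state, value
    # Build a fresh state of the same type, leaving the input untouched.
    return ToyState(params={"mu": new_mu}, aux=value), value


# Binding `lr` gives the unified update(state, batch, inplace=False) API.
update = partial(toy_update, lr=0.1)

state = ToyState(params={"mu": 0.0})
new_state, aux = update(state, [1.0, 2.0, 3.0])
```

The non-inplace branch returns a new object of the same class as the input, which matches the requirement that the returned state be of the same type as the input state.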

posteriors.types.Transform

Bases: NamedTuple

A transform contains init and update functions defining an iterative algorithm.

Within the Transform, all algorithm-specific arguments are predefined, so that the init and update functions have a unified API:

state = transform.init(params)
state, aux = transform.update(state, batch, inplace=False)

Note that a Transform is what an algorithm's build function returns; the internal init and update functions in the algorithm module can, and likely will, take additional arguments.

Attributes:

| Name | Type | Description |
|------|------|-------------|
| init | InitFn | The init function. |
| update | UpdateFn | The update function. |

Source code in posteriors/types.py
class Transform(NamedTuple):
    """A transform contains `init` and `update` functions defining an iterative
        algorithm.

    Within the `Transform` all algorithm specific arguments are predefined, so that the
    `init` and `update` functions have a unified API:
    ```
    state = transform.init(params)
    state, aux = transform.update(state, batch, inplace=False)
    ```

    Note that a `Transform` is what an algorithm's `build` function returns; the
    internal `init` and `update` functions in the algorithm module can, and likely
    will, take additional arguments.

    Attributes:
        init: The init function.
        update: The update function.

    """

    init: InitFn
    update: UpdateFn