Source code for typinox.tree

import jaxtyping
from beartype.typing import Generator, Iterable
from jax import (
    numpy as jnp,
    tree as jt,
)

from .vmapped import Vmapped

PyTreeDef = type[jaxtyping.PyTreeDef]  # to make mypy happy


[docs] def stack[T](trees: Iterable[T]) -> Vmapped[T, " n"]: """Stacks every corresponding leaf on a list of PyTrees. For example, given two trees ((a, b), c) and ((a', b'), c'), returns ((stack(a, a'), stack(b, b')), stack(c, c')). Useful for turning a list of objects into something you can feed to a vmapped function. Parameters ---------- T : type, type parameter The type of the PyTrees trees : Iterable[T] A list of PyTrees, all with the same structure ``T`` Returns ------- Vmapped[T, "n"] A PyTree with the same structure as the input trees, but with the leaves stacked Examples -------- >>> a = ({ "l": jnp.array([1, 2]), "r": jnp.array(3) }, ... jnp.array([4, 5, 6])) >>> b = ({ "l": jnp.array([7, 8]), "r": jnp.array(9) }, ... jnp.array([10, 11, 12])) >>> c = typinox.tree.stack([a, b]) >>> d = ({ "l": jnp.array([[1, 2], [7, 8]]), "r": jnp.array([3, 9]) }, ... jnp.array([[4, 5, 6], [10, 11, 12]])) >>> chex.assert_trees_all_equal(c, d) """ leaves_list = [] treedef_list: list[PyTreeDef] = [] for tree in trees: leaves, treedef = jt.flatten(tree) leaves_list.append(leaves) treedef_list.append(treedef) grouped_leaves = zip(*leaves_list) result_leaves = [jnp.stack(l) for l in grouped_leaves] return treedef_list[0].unflatten(result_leaves)
[docs] def unstack[T](tree: Vmapped[T, " _"]) -> Generator[T]: """Unstacks a PyTree of arrays into a list of PyTrees. Inverse of :func:`stack`. For example, given a tree ((a, b), c), where a, b, and c all have first dimension k, will make k trees [((a[0], b[0]), c[0]), ..., ((a[k], b[k]), c[k])]. Useful for turning the output of a vmapped function into normal objects. Parameters ---------- T : type, type parameter The type of the PyTrees tree : Vmapped[T, " _"] A PyTree of structure ``T`` with the leaves stacked Yields ------ Generator[T] A generator that yields PyTrees with the same structure ``T`` as the input tree, but with the leaves unstacked Examples -------- >>> a = ({ "l": jnp.array([[1, 2], [3, 4]]), "r": jnp.array([5, 6]) }, ... jnp.array([7, 8])) >>> aa = list(typinox.tree.unstack(a)) >>> bb = [ ... ({ "l": jnp.array([1, 2]), "r": jnp.array(5) }, ... jnp.array(7)), ... ({ "l": jnp.array([3, 4]), "r": jnp.array(6) }, ... jnp.array(8)) ... ] >>> chex.assert_trees_all_equal(aa, bb) """ treedef: PyTreeDef # to make mypy happy leaves, treedef = jt.flatten(tree) n_trees = leaves[0].shape[0] for i in range(n_trees): new_leaves = [leaf[i] for leaf in leaves] yield treedef.unflatten(new_leaves)