typinox.tree module
- typinox.tree.stack(trees: Iterable) → VmappedT[T, n]
Stacks the corresponding leaves of a list of PyTrees along a new leading axis.
For example, given two trees ((a, b), c) and ((a’, b’), c’), returns ((stack(a, a’), stack(b, b’)), stack(c, c’)). Useful for turning a list of objects into something you can feed to a vmapped function.
- Parameters:
T (type, type parameter) – The type of the PyTrees
trees (Iterable[T]) – A list of PyTrees, all with the same structure T
- Returns:
A PyTree with the same structure as the input trees, but with the leaves stacked
- Return type:
Vmapped[T, "n"]
Examples
>>> a = ({ "l": jnp.array([1, 2]), "r": jnp.array(3) },
...      jnp.array([4, 5, 6]))
>>> b = ({ "l": jnp.array([7, 8]), "r": jnp.array(9) },
...      jnp.array([10, 11, 12]))
>>> c = typinox.tree.stack([a, b])
>>> d = ({ "l": jnp.array([[1, 2], [7, 8]]), "r": jnp.array([3, 9]) },
...      jnp.array([[4, 5, 6], [10, 11, 12]]))
>>> chex.assert_trees_all_equal(c, d)
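The leaf-wise stacking described above can be reproduced with plain JAX utilities. The following is a minimal sketch, not typinox's actual implementation: stack_sketch is a hypothetical name, and it assumes every input tree has the same structure and that corresponding leaves have matching shapes.

```python
import jax
import jax.numpy as jnp


def stack_sketch(trees):
    """Stack corresponding leaves of same-structured PyTrees along a new axis 0.

    Hypothetical helper for illustration; not part of typinox.
    """
    trees = list(trees)  # accept any iterable, like the documented signature
    # tree_map calls the function once per leaf position, passing that leaf
    # from every tree; jnp.stack joins them along a new leading axis.
    return jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *trees)


a = ({"l": jnp.array([1, 2]), "r": jnp.array(3)}, jnp.array([4, 5, 6]))
b = ({"l": jnp.array([7, 8]), "r": jnp.array(9)}, jnp.array([10, 11, 12]))
c = stack_sketch([a, b])
```

Each leaf of c gains a leading axis of size 2 (the number of input trees), which is the axis a vmapped function maps over by default.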
- typinox.tree.unstack(tree: VmappedT[T, _]) → Generator
Unstacks a PyTree of arrays into a sequence of PyTrees. Inverse of stack().
For example, given a tree ((a, b), c), where a, b, and c all have first dimension k, yields k trees ((a[0], b[0]), c[0]), …, ((a[k-1], b[k-1]), c[k-1]).
Useful for turning the output of a vmapped function into normal objects.
- Parameters:
T (type, type parameter) – The type of the PyTrees
tree (Vmapped[T, " _"]) – A PyTree of structure T with the leaves stacked
- Yields:
Generator[T] – A generator that yields PyTrees with the same structure T as the input tree, but with the leaves unstacked
Examples
>>> a = ({ "l": jnp.array([[1, 2], [3, 4]]), "r": jnp.array([5, 6]) },
...      jnp.array([7, 8]))
>>> aa = list(typinox.tree.unstack(a))
>>> bb = [
...     ({ "l": jnp.array([1, 2]), "r": jnp.array(5) },
...      jnp.array(7)),
...     ({ "l": jnp.array([3, 4]), "r": jnp.array(6) },
...      jnp.array(8))
... ]
>>> chex.assert_trees_all_equal(aa, bb)
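To illustrate the "output of a vmapped function" use case, here is a rough sketch that unstacks the result of jax.vmap using only jax.tree_util. It is not typinox's implementation: unstack_sketch and scale are hypothetical names, and the sketch assumes the tree has at least one leaf and that all leaves share the same leading dimension.

```python
import jax
import jax.numpy as jnp


def unstack_sketch(tree):
    """Yield one PyTree per index along the leading axis of every leaf.

    Hypothetical helper for illustration; not part of typinox.
    """
    leaves, treedef = jax.tree_util.tree_flatten(tree)
    n = leaves[0].shape[0]  # assumes at least one leaf and a shared leading size
    for i in range(n):
        # Rebuild the original structure from the i-th slice of each leaf.
        yield jax.tree_util.tree_unflatten(treedef, [leaf[i] for leaf in leaves])


def scale(point):
    # Hypothetical per-object function: doubles both fields of one point.
    return {"x": point["x"] * 2, "y": point["y"] * 2}


# A batch of two points, with leaves already stacked along axis 0.
batch = {"x": jnp.array([1.0, 2.0]), "y": jnp.array([3.0, 4.0])}
results = list(unstack_sketch(jax.vmap(scale)(batch)))
```

Here results is an ordinary Python list with one dictionary per batch element, so downstream code can treat each result as a normal object again.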