Vmapped annotation


typinox.Vmapped

alias of VmappedT

class typinox.VmappedT

Vmapped type hint.

When VmappedT[inner, shape] is used as a type hint, the annotated object is expected to be a PyTree with the structure described by inner, in which every array leaf has been extended by the extra dimensions given in shape. For example, VmappedT[Linear, "device batch"] is a type hint for a Linear PyTree in which every array has two extra leading dimensions, named device and batch.
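As a rough model of this rule, the sketch below represents a PyTree as a plain dict mapping leaf names to array shapes and prepends the vmapped dimensions to every leaf. The helper vmapped_leaf_shapes and the field names weight and bias are hypothetical, chosen only for illustration; typinox itself works on real PyTrees, not on dicts of shapes.

```python
def vmapped_leaf_shapes(inner_shapes, extra_dims):
    """Prepend `extra_dims` to the shape of every leaf.

    Hypothetical model (not typinox code): a PyTree is a dict mapping
    leaf names to array shapes, and `$` is implicitly at the end, so every
    array gains the same leading axes.
    """
    prefix = tuple(extra_dims)
    return {name: prefix + shape for name, shape in inner_shapes.items()}


# Assumed leaf shapes of one un-vmapped linear layer
# (in_features=3, out_features=2); the field names are illustrative.
linear = {"weight": (2, 3), "bias": (2,)}

# "device batch" with device=4 and batch=8.
stacked = vmapped_leaf_shapes(linear, (4, 8))
print(stacked)  # {'weight': (4, 8, 2, 3), 'bias': (4, 8, 2)}
```

Every leaf receives the same prefix, which is why a single shape string is enough to describe the whole vmapped PyTree.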

Shape

The shape is a space-separated string of dimension specifiers, such as "a 3". Each dimension can be one of the following:

  • A named dimension, such as a.

  • A fixed dimension, such as 3.

  • A symbolic expression in terms of other variable-size axes, such as a-1.

  • _: An anonymous dimension, which behaves like a named dimension but is not matched against any other dimension.

These behave like the corresponding dimension specifiers in jaxtyping. A dimension can also be:

  • $: Stands for the original shape of the array before vmapping. It is optional; if it is omitted, $ is assumed to appear at the end.

    For example, if T is Float[Array, "3 4"], then Vmapped[T, "batch"] and Vmapped[T, "batch $"] are both equivalent to Float[Array, "batch 3 4"], while Vmapped[T, "$ batch"] is equivalent to Float[Array, "3 4 batch"].
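The placement rule for $ can be sketched as a small string transformation: split the spec, append $ if it is missing, and splice the inner dimensions in at its position. The function resolve_vmapped_shape is hypothetical and only models the documented equivalences; rejecting a repeated $ is an assumption of this sketch, not documented behaviour.

```python
def resolve_vmapped_shape(inner_dims, vmap_spec):
    """Expand a Vmapped shape spec into a full jaxtyping-style shape string.

    inner_dims: dimension specifiers of the un-vmapped array, e.g. ["3", "4"].
    vmap_spec:  a Vmapped shape such as "batch", "batch $" or "$ batch".
    """
    tokens = vmap_spec.split()
    if "$" not in tokens:
        # An omitted `$` is treated as if it were written at the end.
        tokens.append("$")
    if tokens.count("$") > 1:
        # Assumption of this sketch: `$` may appear only once.
        raise ValueError("`$` may appear only once")
    resolved = []
    for token in tokens:
        if token == "$":
            resolved.extend(inner_dims)  # splice in the original shape
        else:
            resolved.append(token)
    return " ".join(resolved)


# The equivalences from the example above, with T = Float[Array, "3 4"]:
print(resolve_vmapped_shape(["3", "4"], "batch"))    # batch 3 4
print(resolve_vmapped_shape(["3", "4"], "batch $"))  # batch 3 4
print(resolve_vmapped_shape(["3", "4"], "$ batch"))  # 3 4 batch
```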

Other jaxtyping shape modifiers, such as *, ?, and ..., are not supported.

Warning

No axis specified in shape may have size zero (for example, Vmapped[T, "0 3"] is invalid).
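Taken together, the rules above amount to a small per-token grammar. The sketch below classifies each dimension of a shape string and rejects the unsupported jaxtyping modifiers and literal zero sizes. It is a hypothetical validator written for this page, not typinox's parser: any token that is not $, _, an integer or a plain identifier is simply assumed to be a symbolic expression, and jaxtyping modifiers other than *, ? and ... are not handled.

```python
import re

_IDENTIFIER = re.compile(r"[A-Za-z_][A-Za-z0-9_]*")


def classify_dim(token):
    """Return the kind of a single dimension specifier, or raise ValueError."""
    if token == "$":
        return "original"  # the un-vmapped shape
    if token == "_":
        return "anonymous"
    if token == "..." or token.startswith(("*", "?")):
        raise ValueError(f"unsupported jaxtyping modifier: {token!r}")
    if token.isdecimal():
        if int(token) == 0:
            raise ValueError("vmapped axes cannot have size zero")
        return "fixed"
    if _IDENTIFIER.fullmatch(token):
        return "named"
    # Anything else (e.g. "a-1") is assumed to be a symbolic expression.
    return "expression"


def classify_shape(spec):
    """Classify every space-separated token of a Vmapped shape string."""
    return [(token, classify_dim(token)) for token in spec.split()]


print(classify_shape("a 3 a-1 _ $"))
# [('a', 'named'), ('3', 'fixed'), ('a-1', 'expression'), ('_', 'anonymous'), ('$', 'original')]
```

Checking _ before the identifier pattern matters here, since _ would otherwise be classified as an ordinary named dimension.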

Inner type

The inner type hint describes a PyTree of arrays. It can be any valid type hint, including another VmappedT type hint.

VmappedT is most useful when isinstance(x, inner) also checks the shapes of the arrays in x. Examples of such inner hints include:

  • JAX or NumPy arrays: isinstance(x, VmappedT[Array, "b c"]) checks that x is a JAX array with at least two dimensions, the first two of which are matched to b and c.

  • Tuples of arrays: isinstance(x, VmappedT[tuple[Array, Array], "b c"]) checks that both elements of x are JAX arrays with at least two dimensions, and that their first two dimensions match.

  • jaxtyping shaped arrays: VmappedT[Float[Array, "b c"], "a"] is equivalent to Float[Array, "a b c"], while VmappedT[Float[Array, "b c"], "$ a"] is equivalent to Float[Array, "b c a"].

  • Typinox modules: isinstance(x, VmappedT[M, "b c"]) checks that x is an instance of a TypedModule class M and that every array field of x has at least two dimensions. Once these two leading dimensions are removed from every array, the resulting object should match M.
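The checks above follow a common pattern: the axes listed in shape must be present at the front of every array leaf, and once they are stripped, what remains should match inner. In the tuple example, a name such as b must also denote the same size in every element. The sketch below models only the leading-axis step on plain shape tuples; match_leading_dims is a hypothetical helper written for illustration, not typinox's implementation, and it ignores $, _ and symbolic expressions.

```python
def match_leading_dims(leaf_shapes, spec):
    """Check the leading axes of every leaf against `spec` and strip them.

    leaf_shapes: array shapes (tuples of ints) of the leaves of one PyTree.
    spec: space-separated names or positive integers, e.g. "b c".
    Returns (bindings, stripped), where `bindings` maps each name to its size
    and `stripped` holds the leaf shapes with the leading axes removed.
    """
    dims = spec.split()
    bindings = {}
    stripped = []
    for shape in leaf_shapes:
        if len(shape) < len(dims):
            raise TypeError(f"shape {shape} has fewer than {len(dims)} dimensions")
        for dim, size in zip(dims, shape):
            # A name is bound by the first leaf that uses it; later leaves
            # (and fixed sizes) must agree with that value.
            expected = int(dim) if dim.isdecimal() else bindings.setdefault(dim, size)
            if size != expected:
                raise TypeError(f"axis {dim!r}: expected {expected}, got {size}")
        stripped.append(shape[len(dims):])
    return bindings, stripped


# Two leaves that share leading axes b=2 and c=5, like the tuple example above.
print(match_leading_dims([(2, 5, 3), (2, 5)], "b c"))
# ({'b': 2, 'c': 5}, [(3,), ()])
```

The stripped shapes, here (3,) and (), are what would then have to match inner.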

class typinox.VmappedI

Vmapped type hint that does not use beartype.

Use this only when you do not want beartype's is_bearable() to be used to check the type of the object. This is useful, for example, when the inner annotation is a jaxtyping array type hint.

The interface is the same as VmappedT.