Vmapped annotation
- class typinox.VmappedT
Vmapped type hint.
When VmappedT[inner, shape] is used as a type hint, the object is expected to be a PyTree with the structure specified by inner, in which every array is extended by the axes given in shape. For example, VmappedT[Linear, "device batch"] is a type hint for a Linear PyTree in which every array's shape is prepended with two dimensions, named device and batch.

Shape
The shape is a space-separated string of dimension specifiers, such as "a 3". Each specifier can be one of the following:

- A named dimension, such as a.
- A fixed dimension, such as 3.
- A symbolic expression in terms of other variable-size axes, such as a-1.
- _: an anonymous dimension, which is a named dimension that is not matched to any other dimension.
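The pure-Python sketch below illustrates the four kinds of specifiers listed above. The helpers classify_dim and parse_shape are hypothetical and are not part of typinox. The sketch also rejects the jaxtyping modifiers and zero-sized axes that this section rules out further on. It does not handle the $ marker.

```python
import re

# Identifier-like tokens such as "batch" or "a"; "_" is special-cased first.
_NAME = re.compile(r"[A-Za-z_][A-Za-z0-9_]*")


def classify_dim(token: str) -> str:
    """Classify one specifier of a VmappedT shape string (hypothetical helper)."""
    if token == "_":
        return "anonymous"          # named, but never matched to another dimension
    if token == "..." or token[0] in "*?":
        raise ValueError(f"unsupported jaxtyping modifier: {token!r}")
    if token.isdigit():
        if int(token) == 0:
            raise ValueError("vmapped axes cannot be zero")
        return "fixed"
    if _NAME.fullmatch(token):
        return "named"
    return "symbolic"               # an expression over other axes, e.g. "a-1"


def parse_shape(spec: str) -> list[tuple[str, str]]:
    """Split a space-separated shape string and classify every specifier."""
    return [(tok, classify_dim(tok)) for tok in spec.split()]


print(parse_shape("a 3 a-1 _"))
# [('a', 'named'), ('3', 'fixed'), ('a-1', 'symbolic'), ('_', 'anonymous')]
```

The sketch only classifies tokens. A real checker would also have to bind named dimensions to concrete sizes and evaluate symbolic expressions against those bindings.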
These specifiers function similarly to the dimension specifiers in jaxtyping. A shape string can also contain:

- $: represents the original array shape before vmapping. It is optional; if omitted, $ is assumed to appear at the end.

For example, if T is Float[Array, "3 4"], then Vmapped[T, "batch"] and Vmapped[T, "batch $"] are both equivalent to Float[Array, "batch 3 4"], while Vmapped[T, "$ batch"] is equivalent to Float[Array, "3 4 batch"].
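The placement rule for $ can be sketched as a token splice. The helper expand_shape is hypothetical and works on dimension names as strings. It assumes the shape string contains at most one $. It reproduces the three equivalences in the example above.

```python
def expand_shape(spec: str, inner_dims: list[str]) -> list[str]:
    """Splice the pre-vmap dims into `spec` at `$`, which defaults to the end.

    Hypothetical helper; assumes `spec` contains at most one `$`.
    """
    tokens = spec.split()
    if "$" not in tokens:
        tokens.append("$")          # an omitted `$` means the original shape comes last
    i = tokens.index("$")
    return tokens[:i] + list(inner_dims) + tokens[i + 1:]


inner = ["3", "4"]                     # the dims of T = Float[Array, "3 4"]
print(expand_shape("batch", inner))    # ['batch', '3', '4']
print(expand_shape("batch $", inner))  # ['batch', '3', '4']
print(expand_shape("$ batch", inner))  # ['3', '4', 'batch']
```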
Other jaxtyping array annotation modifiers, such as *, ?, and ..., are not supported.

Warning
None of the axes specified in shape can be zero (e.g., Vmapped[T, "0 3"] is invalid).

Inner type
The inner argument is the type hint for a PyTree structure of arrays. It can be any valid type hint, including another VmappedT type hint.

VmappedT is most useful when isinstance(x, inner) also checks the shapes of the arrays in x. Examples of such inner hints include:

- (jax or numpy) arrays:
  isinstance(x, VmappedT[Array, "b c"]) checks that x is a jax array with at least two initial dimensions.
- tuple of arrays:
  isinstance(x, VmappedT[tuple[Array, Array], "b c"]) checks that both elements of x are jax arrays with at least two initial dimensions, and that their first two dimensions match.
- jaxtyping shaped arrays:
  VmappedT[Float[Array, "b c"], "a"] is equivalent to Float[Array, "a b c"], while VmappedT[Float[Array, "b c"], "$ a"] is equivalent to Float[Array, "b c a"].
- Typinox modules:
  isinstance(x, VmappedT[M, "b c"]) checks that x is an instance of a TypedModule class M and that each array member of x has at least two initial dimensions. If these dimensions are removed, the resulting object should match M.
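The checks in this list can be summarized in three steps:

- Every array must carry the vmapped leading axes.
- Those axes must agree across all arrays.
- Whatever remains after removing them must satisfy inner.

The sketch below models a PyTree as a dict from leaf names to shapes. The names check_vmapped and linear_ok are hypothetical. The weight/bias layout is an assumed Linear layout. Only axis sizes are compared, not names such as b and c. This is not typinox's implementation.

```python
from typing import Callable

Shape = tuple[int, ...]


def check_vmapped(leaves: dict[str, Shape], n_axes: int,
                  inner_check: Callable[[dict[str, Shape]], bool]) -> bool:
    """Sketch of a VmappedT-style check over a PyTree modeled as {leaf name: shape}."""
    prefixes = set()
    for shape in leaves.values():
        if len(shape) < n_axes:
            return False          # every array needs at least n_axes leading dims
        prefixes.add(shape[:n_axes])
    if len(prefixes) > 1:
        return False              # the vmapped dims must agree across all arrays
    # Remove the vmapped dims and check what is left against the inner hint.
    return inner_check({name: shape[n_axes:] for name, shape in leaves.items()})


def linear_ok(leaves: dict[str, Shape]) -> bool:
    """Hypothetical inner check: weight is (out, in) and bias is (out,)."""
    w, b = leaves["weight"], leaves["bias"]
    return len(w) == 2 and b == (w[0],)


print(check_vmapped({"weight": (2, 8, 3, 5), "bias": (2, 8, 3)}, 2, linear_ok))  # True
print(check_vmapped({"weight": (2, 8, 3, 5), "bias": (2, 7, 3)}, 2, linear_ok))  # False: leading axes differ
```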
- class typinox.VmappedI
Vmapped type hint that does not use beartype. Use it only when you do not want is_bearable() to check the type of the object; this is useful when the inner annotation is a jaxtyping array type hint. The interface is the same as VmappedT.
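The text says VmappedI avoids is_bearable() and suits jaxtyping array hints. One plausible reason follows, though it is an interpretation and the text does not state it. jaxtyping hints such as Float[Array, "b c"] are classes that support isinstance() directly. Hints such as tuple[Array, Array], by contrast, are parameterized generics that plain isinstance() rejects, so they need a full checker like beartype's is_bearable().

The standard-library sketch below demonstrates that isinstance() limitation. The helpers isinstance_supports and shallow_check are hypothetical. shallow_check handles only fixed-length tuple[...] hints and plain classes.

```python
from typing import get_args, get_origin


def isinstance_supports(hint: object) -> bool:
    """Return True if plain isinstance() accepts `hint` as its second argument."""
    try:
        isinstance(None, hint)
    except TypeError:
        return False
    return True


def shallow_check(x: object, hint: object) -> bool:
    """Tiny structural checker for fixed-length tuple[...] hints and plain classes.

    A hypothetical stand-in for what a full checker such as beartype's
    is_bearable() does far more thoroughly.
    """
    if get_origin(hint) is tuple:
        args = get_args(hint)
        return (isinstance(x, tuple) and len(x) == len(args)
                and all(shallow_check(v, a) for v, a in zip(x, args)))
    return isinstance(x, hint)      # plain classes: isinstance() is enough


print(isinstance_supports(int))                  # True
print(isinstance_supports(tuple[int, int]))      # False: parameterized generic
print(shallow_check((1, 2), tuple[int, int]))    # True
print(shallow_check((1, "a"), tuple[int, int]))  # False
```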