Typinox
Typinox (TAI-pee-nox) is a Python library for enhancing run-time type-checking of jaxtyping-annotated arrays and equinox.Modules.
Note
Typinox is currently in very early stages and is not yet ready for general use. The documentation is also a work in progress.
Installation
To use Typinox, first install it using pip:
```console
$ pip install typinox
```
Python 3.12 or later is required.
Basic usage
Typinox has two main components: typinox.vmapped, which provides the Vmapped annotation for values passed to or returned from vmap-ed functions, and typinox.module, which provides the TypedModule class for run-time type-checking of Equinox modules.
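With Vmapped[T, "dims"] you can annotate values that are valid arguments to, or results of, a vmap-ed function. For example, let T be the type of a pair of arrays:

```python
import jax.numpy as jnp
from jaxtyping import Array, Float
T = tuple[Float[Array, "3"], Float[Array, "2"]]
x: T = (jnp.zeros(3), jnp.ones(2))  # a value of type T
```
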
Then Vmapped[T, "n"] is equivalent to tuple[Float[Array, "n 3"], Float[Array, "n 2"]]:
```python
from typinox.vmapped import Vmapped
assert isinstance((jnp.zeros((4, 3)), jnp.ones((4, 2))), Vmapped[T, "n"])
```
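You can also use it with jax.vmap():

```python
import jax

def my_function(_) -> T:
    return (jnp.zeros(3), jnp.ones(2))

# Outputs ignore the mapped input, so vmap broadcasts them to a leading axis of 4.
a = jax.vmap(my_function)(jnp.arange(4))
assert isinstance(a, Vmapped[T, "n"])
```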
TypedModule is an extension of equinox.Module with automatic run-time type-checking. It uses jaxtyping's jaxtyped() together with beartype's beartype() to check the arguments and return values of method calls.
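```python
from typing import Self
import jax.numpy as jnp
from jaxtyping import Array, Float
from typinox.module import TypedModule

class AffineMap(TypedModule):
    k: Float[Array, "n m"]
    b: Float[Array, "n"]

    def __call__(self: Self, x: Float[Array, "m"]) -> Float[Array, "n"]:
        return jnp.dot(self.k, x) + self.b

f = AffineMap(k=jnp.ones((3, 2)).astype(float), b=jnp.zeros(3))  # n = 3, m = 2
```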
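The checks described for TypedModule can be pictured with another hypothetical pure-Python model. It assumes, as the shared names n and m in AffineMap suggest, that the fields of self, the argument x and the return value are checked against one common set of dimension bindings. The names check_dims and check_affine_call are illustrative only, not typinox APIs; in the library the actual checking is done by jaxtyping and beartype.

```python
# Hypothetical model of the dimension checks on AffineMap.__call__.
# Shapes are plain tuples; the real work is done by jaxtyping + beartype.
def check_dims(spec: str, shape: tuple[int, ...], memo: dict[str, int]) -> None:
    """Check `shape` against a spec like "n m", binding names in `memo`."""
    names = spec.split()
    if len(names) != len(shape):
        raise TypeError(f"expected rank {len(names)} for {spec!r}, got {shape}")
    for name, size in zip(names, shape):
        expected = int(name) if name.isdigit() else memo.setdefault(name, size)
        if expected != size:
            raise TypeError(f"dim {name!r}: expected {expected}, got {size}")


def check_affine_call(k_shape, b_shape, x_shape, out_shape) -> dict[str, int]:
    """Run the checks in order: fields, then argument, then return value."""
    memo: dict[str, int] = {}
    check_dims("n m", k_shape, memo)  # field k: Float[Array, "n m"]
    check_dims("n", b_shape, memo)  # field b: Float[Array, "n"]
    check_dims("m", x_shape, memo)  # argument x: Float[Array, "m"]
    check_dims("n", out_shape, memo)  # return: Float[Array, "n"]
    return memo


# With k of shape (3, 2) and b of shape (3,), n = 3 and m = 2.
print(check_affine_call((3, 2), (3,), (2,), (3,)))  # {'n': 3, 'm': 2}
try:
    check_affine_call((3, 2), (3,), (3,), (3,))  # x has the wrong length
except TypeError as err:
    print(err)  # dim 'm': expected 2, got 3
```

In this model an input x of shape (3,) is rejected because m is already bound to 2 by the field k, before any arithmetic runs.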
TypedModules are designed to work seamlessly with Vmapped: if a function returns an AffineMap, then its jax.vmap-ed version returns a Vmapped[AffineMap, "n"].
Check out the Basic Usage section for further information.
Dependencies
Typinox aggressively tracks the latest versions of its dependencies. It currently depends on:
- Python ≥ 3.12 (for PEP 695 syntax)
- beartype ≥ 0.20.0
- jaxtyping ≥ 0.2.38
- equinox ≥ 0.11.12
Typinox may drop support for older versions of these dependencies if newer ones provide any benefits that Typinox can leverage.
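To check an existing environment against this list, a script along the following lines could be used. It is a sketch that assumes the listed versions are minimums and that release numbers are dot-separated integers; it relies only on importlib.metadata from the standard library, and MINIMUMS, parse_release, meets_minimum and report are hypothetical names.

```python
# Sketch: compare installed versions against the minimums listed above.
import sys
from importlib.metadata import PackageNotFoundError, version

MINIMUMS = {"beartype": (0, 20, 0), "jaxtyping": (0, 2, 38), "equinox": (0, 11, 12)}


def parse_release(text: str) -> tuple[int, ...]:
    """Leading all-digit segments, e.g. "0.11.12.post1" -> (0, 11, 12).

    A segment such as "0rc1" stops parsing, so pre-releases compare as older
    than the final release (a conservative choice).
    """
    parts = []
    for piece in text.split("."):
        if not piece.isdigit():
            break
        parts.append(int(piece))
    return tuple(parts)


def meets_minimum(installed: str, minimum: tuple[int, ...]) -> bool:
    return parse_release(installed) >= minimum


def report() -> dict[str, bool]:
    """Map each requirement to True if it is installed and new enough."""
    results = {"python": sys.version_info >= (3, 12)}
    for name, minimum in MINIMUMS.items():
        try:
            results[name] = meets_minimum(version(name), minimum)
        except PackageNotFoundError:
            results[name] = False
    return results


if __name__ == "__main__":
    for name, ok in report().items():
        print(f"{name}: {'ok' if ok else 'missing or too old'}")
```

A full implementation would use a proper version parser such as the one in the packaging project; the hand-rolled parser here only keeps the sketch dependency-free.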