Typinox#
Typinox (TAI-pee-nox) is a Python library for enhancing run-time type-checking of jaxtyping-annotated arrays and equinox.Modules.
Note
Typinox is currently in very early stages and is not yet ready for general use. The documentation is also a work in progress.
Installation#
To use Typinox, first install it using pip:
$ pip install typinox
Python 3.12 or later is required.
Basic usage#
Typinox has two main components: typinox.vmapped, providing a Vmapped annotation for vmap-compatible functions, and typinox.module, providing a TypedModule class for run-time type-checking of Equinox modules.
With Vmapped[T, "dims"] you can annotate variables that are valid arguments to a vmap-ed function. For example, if
T = tuple[Float[Array, "3"], Float[Array, "2"]]
(jnp.zeros(3), jnp.ones(2)) : T
then Vmapped[T, "n"] is equivalent to tuple[Float[Array, "n 3"], Float[Array, "n 2"]]:
assert isinstance((jnp.zeros((4, 3)), jnp.ones((4, 2))), Vmapped[T, "n"])
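The equivalence above amounts to prepending the batch dimension names to the shape string of every array in T. The following is a minimal plain-Python sketch of that string rewriting only. The helper name add_leading_dims is hypothetical and is not part of Typinox, and the sketch says nothing about how Vmapped is actually implemented.

```python
def add_leading_dims(dims: str, shape: str) -> str:
    """Prepend batch dimension names to a jaxtyping-style shape string."""
    # Shape strings are space-separated axis names or sizes, e.g. "n 3".
    return " ".join(dims.split() + shape.split())


# Hypothetical stand-in for T: only the shape strings of its two leaves.
leaf_shapes = ("3", "2")
batched = tuple(add_leading_dims("n", s) for s in leaf_shapes)
print(batched)  # ('n 3', 'n 2')
```

The same rewriting applies when more than one batch dimension is given, for example add_leading_dims("a b", "3") yields "a b 3".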
And you can use it with vmap():
def my_function(_) -> T:
    return (jnp.zeros(3), jnp.ones(2))

a = jax.vmap(my_function)(jnp.arange(4))
assert isinstance(a, Vmapped[T, "n"])
TypedModule is an extension of equinox.Module with automatic type-checking. It uses jaxtyped() and beartype() to check method calls and return values.
class AffineMap(TypedModule):
    k: Float[Array, "n m"]
    b: Float[Array, "n"]

    def __call__(self: Self, x: Float[Array, "m"]) -> Float[Array, "n"]:
        return jnp.dot(self.k, x) + self.b

f = AffineMap(k=jnp.ones((3, 2)).astype(float), b=jnp.zeros(3))
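The shape strings in these annotations use named axes such as "n" and "m". Under jaxtyping, a name must refer to the same size everywhere it appears within one checked call. The sketch below illustrates that binding rule in plain Python, using the shapes of k and b from the example. The function bind_shape is hypothetical and handles only fixed names and integer literals, whereas jaxtyping's real syntax is considerably richer. It is not how jaxtyped() or Typinox are implemented.

```python
def bind_shape(shape: tuple[int, ...], spec: str, sizes: dict[str, int]) -> bool:
    """Match a concrete shape against a spec like "n m", recording axis sizes.

    Returns False if the rank differs, an integer entry does not match,
    or a name was already bound to a different size. Names bound before
    a failure stay in `sizes`; a real checker would be more careful.
    """
    names = spec.split()
    if len(names) != len(shape):
        return False
    for name, size in zip(names, shape):
        if name.isdigit():
            if int(name) != size:
                return False
        elif sizes.setdefault(name, size) != size:
            return False
    return True


sizes: dict[str, int] = {}
ok_k = bind_shape((3, 2), "n m", sizes)  # binds n=3 and m=2
ok_b = bind_shape((3,), "n", sizes)      # consistent with n=3
ok_x = bind_shape((2,), "m", sizes)      # consistent with m=2
bad_x = bind_shape((5,), "m", sizes)     # rejected: m is already 2
print(ok_k, ok_b, ok_x, bad_x)  # True True True False
```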
TypedModules are designed to work seamlessly with Vmapped. If a function f() returns an AffineMap, then the vmapped function jax.vmap(f) returns a Vmapped[AffineMap, "n"].
Check out the Basic Usage section for further information.
Dependencies#
Typinox aggressively tracks the latest versions of its dependencies. It currently depends on:
Python 3.12 (for PEP 695 syntax)
beartype 0.20.0
jaxtyping 0.2.38
equinox 0.11.12
Typinox may drop support for older versions of these dependencies if newer ones provide any benefits that Typinox can leverage.