Typed Modules

class typinox.TypedModule

Base class. Inherit from this class to create a typinox typed module.

This class is a subclass of equinox.Module and provides automatic type validation for its fields.

Thanks to equinox.Module, every typed module is automatically a dataclass and a pytree.

Fields

Declare typed fields at the class level, using the same syntax as for a dataclass. The class automatically becomes a pytree containing all the fields, and the types of the fields are validated at run time.

from jaxtyping import Array, Float
from typinox import TypedModule

class MyModule(TypedModule):
    weight: Float[Array, "m n"]
    bias: Float[Array, "m"]
    sublayers: list[TypedModule]

Initialization

A default constructor is provided, similar to that of a dataclass. Positional arguments fill the fields in declaration order. You can also pass keyword arguments to fill the fields of the same name. For example, MyModule(w, sublayers=[m1, m2], bias=b).
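The positional and keyword filling described above can be sketched with a plain standard-library dataclass. This is only an analogy: Layer is a hypothetical stand-in, not a TypedModule, and no type validation takes place here.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Layer:
    # Hypothetical stand-in for a typed module with three fields.
    weight: float
    bias: float
    sublayers: tuple = ()


# Positional arguments fill the fields in declaration order.
by_position = Layer(1.0, 0.5, ("child",))

# Keyword arguments fill the fields of the same name, in any order.
by_keyword = Layer(1.0, sublayers=("child",), bias=0.5)

# Both calls build the same value.
same = by_position == by_keyword
```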

Alternatively, you can provide an __init__ method to customize the initialization behavior.

import jax
import jax.numpy as jnp
from jaxtyping import Array, Float, Key
from typinox import TypedModule

class MyModule(TypedModule):
    weight: Float[Array, "m n"]
    bias: Float[Array, "m"]
    sublayers: list[TypedModule]

    def __init__(self, m: int, n: int,
                 key: Key[Array, ""],
                 # None instead of [] avoids sharing one mutable default list
                 sublayers: list[TypedModule] | None = None,
                ):
        self.weight = jax.random.normal(key, (m, n))
        self.bias = jnp.ones((m,))
        self.sublayers = [] if sublayers is None else sublayers

Methods

Define methods at the class level just like in any other class. Every method is automatically wrapped with beartype.beartype() and jaxtyping.jaxtyped() to perform run-time type checking.

class MyModule(TypedModule):
    # ... same as above

    def __call__(self, x: Float[Array, "n"]) -> Float[Array, "m"]:
        y = jnp.dot(self.weight, x) + self.bias
        for layer in self.sublayers:
            y = layer(y)
        # if y is not a Float[Array, "m"] at this point,
        # an error will be raised by beartype
        return y

Tip

The method does not have to be named __call__(); it might as well be named forward(). The dunder name __call__ is not specially treated by typinox.
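To give a feel for what "every method is automatically wrapped" means, here is a toy sketch using only the standard library. It is not how typinox, beartype or jaxtyping are implemented: checked and AutoChecked are hypothetical names, and the checker only understands annotations that are plain classes.

```python
import functools
import inspect
import typing


def checked(func):
    """Toy checker: validates plain-class annotations with isinstance."""
    hints = typing.get_type_hints(func)
    sig = inspect.signature(func)

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        bound = sig.bind(*args, **kwargs)
        for name, value in bound.arguments.items():
            expected = hints.get(name)
            if isinstance(expected, type) and not isinstance(value, expected):
                raise TypeError(f"{name} must be {expected.__name__}")
        result = func(*args, **kwargs)
        expected = hints.get("return")
        if isinstance(expected, type) and not isinstance(result, expected):
            raise TypeError(f"return value must be {expected.__name__}")
        return result

    return wrapper


class AutoChecked:
    """Hypothetical base class that wraps the public methods of subclasses."""

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        for name, attr in list(vars(cls).items()):
            if inspect.isfunction(attr) and not name.startswith("_"):
                setattr(cls, name, checked(attr))


class Scaler(AutoChecked):
    def scale(self, x: int) -> int:
        return x * 2

    def broken(self, x: int) -> int:
        return str(x)  # violates the declared return type
```

With this sketch, Scaler().scale(3) succeeds, while Scaler().scale("3") and Scaler().broken(1) both raise TypeError, without either method asking for checks explicitly.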

_validate() -> None

A helper method to validate the type of the module.

This is particularly useful for modules that may be vmapped. The arguments to and return values of a function transformed by jax.vmap() are pytrees of arrays with added dimensions, and thus may be invalid modules. For example:

def create_module(key: Key[Array, ""]):
    return MyModule(n=3, m=2, key=key, sublayers=[])

key = jax.random.key(1)
some_keys = jax.random.split(key, 5)
some_modules = jax.vmap(create_module)(some_keys)

In this case, some_modules has MyModule as its .__class__, but is not a valid MyModule because its weight and bias have shapes (5, 2, 3) and (5, 2) respectively. Therefore, some_modules._validate() will fail.

Hint

In this case we can annotate it with Vmapped, as it passes the type check for Vmapped[MyModule, "5"].
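The added leading dimension can be seen with plain JAX, without any typinox classes. In this sketch, make_params is a hypothetical function whose returned tuple stands in for the weight and bias fields of a module.

```python
import jax
import jax.numpy as jnp


def make_params(key):
    # Stand-in for a module constructor: one (2, 3) weight and one (2,) bias.
    weight = jax.random.normal(key, (2, 3))
    bias = jnp.ones((2,))
    return weight, bias


key = jax.random.key(1)
some_keys = jax.random.split(key, 5)

# vmap maps over the 5 keys, so every output gains a leading axis of size 5.
weights, biases = jax.vmap(make_params)(some_keys)

weight_shape = weights.shape
bias_shape = biases.shape
```

The resulting shapes no longer match annotations such as Float[Array, "m n"] and Float[Array, "m"], which is why validation of the vmapped result fails.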

Return type:

None

Raises:

TypinoxTypeViolation – If the type of the module is not valid.

typinox.field(*, typecheck: bool = True, metadata: dict[str, Any] | None = None, **kwargs) -> Field

Specify a field of a typinox typed module.

Parameters:
  • typecheck (bool, default True) – If specified as False, the field will be ignored during typechecking.

  • converter (Callable[[Any], Any], optional) – Used by __init__ to pre-process the value. See equinox.field().

  • static (bool, default False) – If specified as True, the field will be static. This will make it a part of the pytree structure, instead of a subtree or a leaf. See equinox.field().

  • default (Any, optional) – Default value used by __init__ if this field is not provided. See dataclasses.field().

  • default_factory (Callable[[], Any] | Any, optional) – Used to generate a default value by __init__ if this field is not provided. See dataclasses.field().

  • metadata (dict[str, Any] | None, optional) – Metadata to be attached to the field.
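The default and default_factory parameters point to dataclasses.field(), so the difference between them can be sketched with the standard library alone. Config is a hypothetical class, and this sketch does not exercise typinox.field() itself.

```python
from dataclasses import dataclass, field


@dataclass
class Config:
    # An immutable default can be given directly.
    name: str = "layer"
    # A mutable default needs a factory so each instance gets its own list.
    tags: list = field(default_factory=list)


first = Config()
second = Config()
first.tags.append("a")  # does not affect second.tags

# Writing a bare mutable default is rejected when the class is defined.
try:
    @dataclass
    class BadConfig:
        tags: list = []
    rejected = False
except ValueError:
    rejected = True
```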

class typinox.TypedPolicy(always_validated: bool = True, typecheck_init_result: bool = True, skip_methods: Iterable[str] = frozenset({}))

Used to configure the typechecking behavior of a module.

As an example, if you want to disable typechecking for a specific method, you can use the following:

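from typinox import TypedModule, TypedPolicy, field

class MyModule(TypedModule,
               typed_policy=TypedPolicy(skip_methods={"wtf"})):
    a: int
    b: int = field(typecheck=False)  # excluded from field type checking

    def wtf(self, x: int):  # listed in skip_methods, so not type checked
        return x

# b is exempt from checking, so a non-int value is accepted here
z = MyModule(1, "not an int")
# wtf is skipped, so a non-int argument is accepted here
z.wtf("also not an int")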

Hint

You can also use the typing.no_type_check() decorator to disable typechecking for a specific method.
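As a minimal sketch of the decorator itself: typing.no_type_check() leaves the function's behavior unchanged and marks it with a __no_type_check__ attribute, which type-checking tools can look for. The function loose is hypothetical, and how typinox reacts to the marker is not shown here.

```python
from typing import no_type_check


@no_type_check
def loose(x: int) -> int:
    # The annotations are kept for readers but flagged as not to be checked.
    return x


# The decorator records its request as an attribute on the function.
marked = loose.__no_type_check__

# Plain Python still runs the function with any argument.
result = loose("not an int")
```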

Parameters:
  • always_validated (bool, default True) – If True, every argument of methods will be type checked with a custom validator if present.

  • typecheck_init_result (bool, default True) – If False, type checking will not be performed on the result of __init__.

  • skip_methods (frozenset[str], default {}) – Specifies which methods should skip type checking.

class typinox.module.TypedModuleMeta(name, bases, dict_, /, strict: bool | None = False, typed_policy: TypedPolicy | dict | None = None, **kwargs)

Metaclass for TypedModule.

If you want to create a module with a metaclass other than abc.ABCMeta, equinox._module._ModuleMeta or typinox.module.TypedModuleMeta, you need to create a new metaclass that inherits from this class and your metaclass of choice.
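The general Python pattern behind this advice can be sketched with stand-in classes. LibraryMeta plays the role of typinox.module.TypedModuleMeta and RegistryMeta is a hypothetical metaclass of your own. Whether typinox needs anything beyond this pattern is not covered here.

```python
class LibraryMeta(type):
    """Stand-in for the metaclass supplied by the library."""


class RegistryMeta(type):
    """Hypothetical user metaclass that records every class it creates."""

    registry = []

    def __init__(cls, name, bases, namespace):
        super().__init__(name, bases, namespace)
        RegistryMeta.registry.append(name)


class Base(metaclass=LibraryMeta):
    """Stand-in for the library's base class."""


# Using an unrelated metaclass on a subclass is a metaclass conflict.
try:
    class Broken(Base, metaclass=RegistryMeta):
        pass
    conflict = False
except TypeError:
    conflict = True


# A metaclass that inherits from both resolves the conflict.
class CombinedMeta(RegistryMeta, LibraryMeta):
    pass


class Works(Base, metaclass=CombinedMeta):
    pass
```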