Typed Modules
- class typinox.TypedModule
Base class. Inherit from this class to create a typinox typed module.
This class is a subclass of equinox.Module and provides automatic type validation for its fields. Thanks to equinox.Module, every typed module is automatically a dataclass and a pytree.
Fields
Declare typed fields at the class level, using the same syntax as for a dataclass. The class automatically becomes a pytree containing all the fields, and the field values are checked against their annotations at runtime (by default, right after __init__ returns).
```python
from jaxtyping import Array, Float
from typinox import TypedModule

class MyModule(TypedModule):
    weight: Float[Array, "m n"]
    bias: Float[Array, "m"]
    sublayers: list[TypedModule]
```
Initialization
A default constructor is provided, as for a dataclass. It fills the fields with the positional arguments in declaration order; keyword arguments fill the fields of the same name. For example:
MyModule(w, sublayers=[m1, m2], bias=b).
Alternatively, you can define an __init__ method to customize initialization:
```python
import jax
import jax.numpy as jnp
from jaxtyping import Array, Float, Key
from typinox import TypedModule

class MyModule(TypedModule):
    weight: Float[Array, "m n"]
    bias: Float[Array, "m"]
    sublayers: list[TypedModule]

    def __init__(
        self,
        m: int,
        n: int,
        key: Key[Array, ""],
        sublayers: list[TypedModule] | None = None,
    ):
        self.weight = jax.random.normal(key, (m, n))
        self.bias = jnp.ones((m,))
        # Use None as the default to avoid sharing one mutable list.
        self.sublayers = [] if sublayers is None else sublayers
```
Methods
Define methods at the class level just like in any other class. Every method is automatically wrapped with beartype.beartype() and jaxtyping.jaxtyped() to perform runtime type checking.
```python
import jax.numpy as jnp
from jaxtyping import Array, Float

class MyModule(TypedModule):
    # ... fields and __init__ as above

    def __call__(self, x: Float[Array, "n"]) -> Float[Array, "m"]:
        y = jnp.dot(self.weight, x) + self.bias
        for layer in self.sublayers:
            y = layer(y)
        # If y is not a Float[Array, "m"] at this point,
        # an error will be raised by beartype.
        return y
```
Tip
The method does not have to be named __call__(); it might as well be named forward(). The dunder name __call__ is not treated specially by typinox.
- _validate() → None
A helper method to validate the type of the module.
This is particularly useful for modules that may be vmapped. The arguments to and the return value of a function transformed by jax.vmap() are pytrees of arrays with an added batch dimension, and thus may be invalid modules. For example:
```python
import jax
from jaxtyping import Array, Key

def create_module(key: Key[Array, ""]):
    return MyModule(n=3, m=2, key=key, sublayers=[])

key = jax.random.key(1)
some_keys = jax.random.split(key, 5)
some_modules = jax.vmap(create_module)(some_keys)
```
In this case, some_modules has MyModule as its .__class__, but it is not a valid MyModule, because its weight and bias have shapes (5, 2, 3) and (5, 2) respectively. Therefore, some_modules._validate() will fail.
Hint
In this case, the batched module can be annotated with Vmapped instead: it passes the type check for Vmapped[MyModule, "5"].
- Return type:
None
- Raises:
TypinoxTypeViolation – If the type of the module is not valid.
- typinox.field(*, typecheck: bool = True, metadata: dict[str, Any] | None = None, **kwargs) → Field
Specify a field of a typinox typed module. Apart from typecheck and metadata, the parameters below are accepted through **kwargs.
- Parameters:
typecheck (bool, default True) – If set to False, the field is ignored during type checking.
converter (Callable[[Any], Any], optional) – Used by __init__ to pre-process the value. See equinox.field().
static (bool, default False) – If set to True, the field is static: it becomes part of the pytree structure instead of a subtree or a leaf. See equinox.field().
default (Any, optional) – Default value used by __init__ if this field is not provided. See dataclasses.field().
default_factory (Callable[[], Any] | Any, optional) – Used by __init__ to generate a default value if this field is not provided. See dataclasses.field().
metadata (dict[str, Any] | None, optional) – Metadata to be attached to the field.
- class typinox.TypedPolicy(always_validated: bool = True, typecheck_init_result: bool = True, skip_methods: Iterable[str] = frozenset({}))
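Used to configure the type checking behavior of a module.
For example, to disable type checking for a specific method:
```python
from typinox import TypedModule, TypedPolicy, field

class MyModule(TypedModule, typed_policy=TypedPolicy(skip_methods={"unchecked"})):
    a: int
    b: int = field(typecheck=False)  # this field is never type checked

    def unchecked(self, x: int):  # listed in skip_methods, so not type checked
        return x

z = MyModule(1, "not an int")  # accepted: b has typecheck=False
z.unchecked("also not an int")  # accepted: the method is skipped
```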
Hint
You can also use the typing.no_type_check() decorator to disable type checking for a specific method.
- Parameters:
always_validated (bool, default True) – If True, every method argument that has a custom validator is also checked with it.
typecheck_init_result (bool, default True) – If False, type checking is not performed on the result of __init__.
skip_methods (Iterable[str], default empty) – Names of the methods that should skip type checking.
- class typinox.module.TypedModuleMeta(name, bases, dict_, /, strict: bool | None = False, typed_policy: TypedPolicy | dict | None = None, **kwargs)
Metaclass for TypedModule.
If you want to create a module with a metaclass other than abc.ABCMeta, equinox._module._ModuleMeta or typinox.module.TypedModuleMeta, you need to create a new metaclass that inherits from both this class and your metaclass of choice.