# Source code for typinox.module

import dataclasses
from typing import TYPE_CHECKING, dataclass_transform

import equinox
import equinox._module

from ._module import (
    TypedPolicy as TypedPolicy,
    field as field,
)
from .error import TypinoxTypeViolation

if TYPE_CHECKING:
    # Workaround for static type checkers

    type AbstractVar[T] = T | property

    @dataclass_transform(
        field_specifiers=(dataclasses.field, equinox.field, field),
    )
    class TypedModuleMeta(equinox._module._ModuleMeta):  # type: ignore
        pass

    class TypedModule(equinox.Module, metaclass=TypedModuleMeta):
        def __validate_self_str__(self) -> str:
            return ""

[docs] def _validate(self) -> None: ...
else: # The real branch; this is what gets executed at runtime from ._module import RealTypedModuleMeta AbstractVar = equinox.AbstractVar TypedModuleMeta = RealTypedModuleMeta
[docs] class TypedModule(equinox.Module, metaclass=TypedModuleMeta): """ Base class. Inherit this to create a typinox typed module. This class is a subclass of :class:`equinox.Module` and provides automatic type validation for its fields. Thanks to :class:`equinox.Module`, every typed module is automatically a `dataclass`_ and a `pytree`_. .. _dataclass: https://docs.python.org/3/library/dataclasses.html .. _pytree: https://docs.jax.dev/en/latest/pytrees.html **Fields** Declare typed fields at the class level, using the same syntax as for `dataclass`_. The class will automatically become a pytree containing all the fields. It will also automatically validate the types of the fields at runtime. .. code-block:: python from jaxtyping import Array, Float, Key class MyModule(TypedModule): weight: Float[Array, "m n"] bias: Float[Array, "m"] sublayers: list[TypedModule] **Initialization** A default constructor is provided similar to a dataclass. It fills each field with the arguments in order. You can also pass keyword arguments to fill the fields of the same name. For example ``MyModule(w, sublayers=[m1, m2], bias=b)``. Alternatively, you can provide an ``__init__`` method to customize the initialization behavior. .. code-block:: python from jaxtyping import Array, Float, Key class MyModule(TypedModule): weight: Float[Array, "m n"] bias: Float[Array, "m"] sublayers: list[TypedModule] def __init__(self, m: int, n: int, key: Key[Array, ""], sublayers: list[TypedModule] = [], ): self.weight = jax.random.normal(key, (m, n)) self.bias = jax.ones((m,)) self.sublayers = sublayers **Methods** Define methods at the class level just like any other class. Every methods is automatically wrapped with :func:`beartype.beartype` and :func:`jaxtyping.jaxtyped` to perform run-time type checking. .. code-block:: python class MyModule(TypedModule): # ... 
same as above def __call__(self, x: Float[Array, "n"]) -> Float[Array, "m"]: y = jnp.dot(self.weight, x) + self.bias for layer in self.sublayers: y = layer(y) # if y is not a Float[Array, "m"] at this point, # an error will be raised by beartype return y .. tip:: The method does not have to be named ``__call__()``; it might as well be named ``forward()``. The dunder name ``__call__`` is not specially treated by typinox. """ pass
def _validate(self) -> None: """A helper method to validate the type of the module. This is particularly useful for modules that may be vmapped. The argument to and return value of :func:`jax.vmap` are pytrees of arrays with added dimensions, and thus may be invalid modules. For example: .. code-block:: python def create_module(key: Key[Array, ""]): return MyModule(n=3, m=2, key=key, sublayers=[]) key = jax.random.key(1) some_keys = jax.random.split(key, 5) some_modules = jax.vmap(create_module)(some_keys) In this case, ``some_modules`` has ``MyModule`` as its ``.__class__``, but is not a valid ``MyModule`` because its ``weight`` and ``bias`` have shapes ``(5, 2, 3)`` and ``(5, 2)`` respectively. Therefore, ``some_modules._validate()`` will fail. .. hint:: In this case we can annotate it with :class:`Vmapped`, as it passes the type check for ``Vmapped[MyModule, "5"]``. Returns ------- None Raises ------ TypinoxTypeViolation If the type of the module is not valid. """ __tracebackhide__ = True cls = type(self) for kls in cls.__mro__[-2::-1]: if hasattr(kls, "__validate_self_str__"): validated = kls.__validate_self_str__(self) if validated != "": raise TypinoxTypeViolation( f"the value ({self}) is not a {cls}, as {validated}" ) # Patch the _validate method to the class # This is to avoid TypedModuleMeta to attempt to # type check _validate, leading to an infinite recursion type.__setattr__(TypedModule, "_validate", _validate)