Custom Type Validator
Typinox supports custom type validators, which are useful for checking the internal consistency of an object.
They are used in a similar way to the
__check_init__() method of an equinox.Module.
Since PyTree operations do not call __check_init__(),
the result of an operation such as jax.vmap() or equinox.filter() may
be invalid, and calling methods on it is dangerous.
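To see why such results can be invalid, it helps to recall that a PyTree container is typically rebuilt from its leaves without running its constructor (Equinox modules, for instance, are recreated via object.__new__ during unflattening), so any check placed in __init__ is skipped. The pure-Python sketch below imitates that behavior; Interval and rebuild_from_leaves are hypothetical names used only for illustration and are not part of Typinox, Equinox, or JAX.

```python
class Interval:
    """Hypothetical container whose constructor enforces lo <= hi."""

    def __init__(self, lo: float, hi: float):
        if lo > hi:
            raise ValueError(f"lo={lo} > hi={hi}")
        self.lo = lo
        self.hi = hi

    def width(self) -> float:
        # Only meaningful while the invariant lo <= hi holds.
        return self.hi - self.lo


def rebuild_from_leaves(cls, leaves: dict):
    """Hypothetical stand-in for PyTree unflattening: allocate the object
    and set its attributes directly, so __init__ (and its check) never runs."""
    obj = object.__new__(cls)
    for name, value in leaves.items():
        object.__setattr__(obj, name, value)
    return obj


ok = Interval(1.0, 2.0)
# A tree-map-like operation that negates every leaf.
leaves = {name: -value for name, value in vars(ok).items()}
bad = rebuild_from_leaves(Interval, leaves)

print(bad.lo, bad.hi)  # -1.0 -2.0: the invariant lo <= hi is broken
print(bad.width())     # -1.0: a silently wrong result instead of an error
```

Calling Interval(-1.0, -2.0) directly would raise, but the rebuilt object slips through; this is the gap that a validator run at method-call time is meant to close.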
For a TypedModule,
the __validate__() method is called whenever another method needs the module
itself to be validated. For example, consider the following module:
```python
from typing import Self

import jax
import jax.numpy as jnp
from jaxtyping import Array, Float

import typinox as tpx
from typinox import ValidationFailed


class SquareMat(tpx.TypedModule):
    mat: Float[Array, "n n"]
    n: int = tpx.field(static=True)

    def __validate__(self):
        # Reject the module if the static size disagrees with the array shape.
        if self.n != self.mat.shape[0]:
            raise ValidationFailed(
                f"n={self.n} does not match mat.shape[0]={self.mat.shape[0]}"
            )

    def diagonal_plus_one(self: Self):
        # The explicit `Self` annotation makes Typinox validate `self` first.
        return jnp.diagonal(self.mat) + 1
```
Because of the explicit Self annotation,
the diagonal_plus_one method validates self before its body runs:
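```python
f = SquareMat(mat=jnp.eye(3), n=3)
f.diagonal_plus_one()  # works

_ = SquareMat(mat=jnp.eye(3), n=4)  # fails: n=4 but mat.shape[0] is 3

# `n` is a static field, so only `mat` is a leaf and gets reshaped.
f = jax.tree.map(lambda x: x.reshape((9, 1)), f)
# Here f.mat has shape (9, 1) while f.n is still 3.
f.diagonal_plus_one()  # fails: self no longer passes validation
```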
There are three ways to define a custom validator:
- A method named __validate__() returning bool. If it returns False, the validation fails.
- A method named __validate__() returning None. It raises a ValidationFailed exception if the validation fails.
- A method named __validate_str__() returning str. If it returns an empty string, the validation passes; a non-empty return value is used as the error message.
Note
When both __validate__() and __validate_str__() are defined, both must pass for the validation to succeed. This is not recommended because it can be confusing.
When an object is type-checked, each of its base classes is checked in order, and all of them must pass for the object to be considered valid.
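The rules above can be modeled in a few lines of plain Python: walk the class hierarchy, call each class's own __validate__() and __validate_str__(), and interpret their return values as described. This is a hypothetical sketch of the documented semantics, not Typinox's code; run_validators, ValidationError, Base, and Child are illustrative names, and walking the MRO from the most derived class first is an assumption about what "in order" means.

```python
class ValidationError(Exception):
    """Hypothetical stand-in for typinox's ValidationFailed."""


def run_validators(obj) -> None:
    """Check `obj` against every class in its MRO; raise on the first failure."""
    for cls in type(obj).__mro__:
        own = vars(cls)  # only methods defined directly on this class
        if "__validate__" in own:
            # bool style: False fails; None style: the method raises by itself.
            result = own["__validate__"](obj)
            if result is False:
                raise ValidationError(f"{cls.__name__}.__validate__ returned False")
        if "__validate_str__" in own:
            # str style: an empty string passes; anything else is the message.
            message = own["__validate_str__"](obj)
            if message:
                raise ValidationError(f"{cls.__name__}: {message}")


class Base:
    def __init__(self, size: int, items: list):
        self.size = size
        self.items = items

    def __validate__(self) -> bool:
        return self.size >= 0


class Child(Base):
    def __validate_str__(self) -> str:
        if len(self.items) != self.size:
            return f"len(items)={len(self.items)} but size={self.size}"
        return ""


run_validators(Child(2, ["a", "b"]))  # passes silently
# Child(2, ["a"]) fails Child.__validate_str__; Child(-1, []) fails as well.
```

Because every class in the hierarchy contributes its own check, a subclass can tighten its parent's invariants but cannot silently drop them.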
Validating outside the class scope
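To validate an object outside the class scope,
use the ValidateT type annotation.
In the previous example, the value of f is invalid after the reshape operation.
A plain check against SquareMat still accepts it, whereas a check against
ValidateT[SquareMat] also runs the validator and rejects it: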
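```python
from beartype.door import is_bearable

from typinox import ValidateT  # assumed import location

is_bearable(f, SquareMat)  # True, even though f is invalid
is_bearable(f, ValidateT[SquareMat])  # False: __validate__ rejects f
```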
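The difference between the two checks can be modeled with a subscriptable wrapper whose instance check runs the validator on top of the ordinary isinstance test. The sketch below is a hypothetical pure-Python analogue of ValidateT[...], not how Typinox implements it; Validated and Pair are illustrative names, it uses isinstance instead of beartype's is_bearable, and ValueError stands in for ValidationFailed.

```python
class _ValidatedMeta(type):
    """Metaclass whose isinstance check also runs the target's validator."""

    def __instancecheck__(cls, obj) -> bool:
        target = cls.target
        if not isinstance(obj, target):
            return False
        try:
            result = obj.__validate__()
        except ValueError:
            return False  # None-style validator signalled failure
        return result is not False  # bool-style validator


class Validated:
    """Hypothetical analogue of ValidateT: Validated[C] means 'a valid C'."""

    def __class_getitem__(cls, target):
        return _ValidatedMeta(f"Validated[{target.__name__}]", (), {"target": target})


class Pair:
    def __init__(self, first, second):
        self.first = first
        self.second = second

    def __validate__(self):
        if self.first > self.second:
            raise ValueError("first must not exceed second")


good = Pair(1, 2)
bad = Pair(1, 2)
bad.first = 5  # mutate past __init__, much like the reshape in the example above

print(isinstance(bad, Pair))             # True
print(isinstance(bad, Validated[Pair]))  # False
print(isinstance(good, Validated[Pair])) # True
```

Keeping validation in a separate wrapper type means ordinary class checks stay cheap, and callers opt into the stricter check only where they need it.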