Source code for typinox._module

import dataclasses
import inspect
import warnings
import weakref
from abc import ABCMeta
from types import FunctionType

import beartype
from beartype.door import is_bearable
from beartype.typing import (
    Annotated,
    Any,
    Callable,
    Iterable,
    Never,
    Self,
    Sequence,
    Unpack,
    cast,
    get_args,
    get_origin,
    overload,
)
from equinox import (
    AbstractClassVar,
    AbstractVar,
    field as eqx_field,
)
from equinox._module._module import (
    _has_dataclass_init,
    _ModuleMeta as EqxModuleMeta,
)
from equinox._module._prebuilt import BoundMethod as EqxWrapMethod
from jaxtyping import jaxtyped

from ._vmapped import (
    AbstractVmapped,
    VmappedMeta,
    get_vmapped_origin_or_none,
)
from .debug import (
    TypinoxUnknownFunctionWarning,
    debug_warn,
)
from .error import TypinoxNotImplementedError, TypinoxTypeViolation
from .shaped import ensure_shape as ensure_shape
from .validator import ValidatedT, ValidationFailed, validate_str

AnnotatedAlias = cast(type, type(Annotated[int, ">3"]))
CallableAliasType = type(Callable[[int], float])
GenericAliasType = type(tuple[int, str])
UnpackType = type(Unpack[tuple[int, str]])
UnionType = type(int | float)
UnionGenericAlias = type(Self | None)


@overload
def field(): ...


@overload
def field(
    *,
    typecheck: bool = True,
    converter: Callable[[Any], Any] = ...,
    static: bool = False,
    default: Any = ...,
    default_factory: Callable[[], Any] | Any = ...,
    init: bool = True,
    hash: bool | None = None,
    metadata: dict[str, Any] | None = None,
    kw_only: bool = ...,
): ...


def field(
    *,
    typecheck: bool = True,
    metadata: dict[str, Any] | None = None,
    **kwargs,
) -> dataclasses.Field:
    """Specify a field of a typinox typed module.

    The ``typecheck`` flag is stored under the ``"typecheck"`` key of the
    field metadata; every other keyword argument is forwarded unchanged to
    :func:`equinox.field`.

    Parameters
    ----------
    typecheck : bool, default True
        If specified as False, the field will be ignored during typechecking.
    converter : Callable[[Any], Any], optional
        Used by ``__init__`` to pre-process the value.
        See :func:`equinox.field`.
    static : bool, default False
        If specified as True, the field will be static. This will make it a
        part of the pytree structure, instead of a subtree or a leaf.
        See :func:`equinox.field`.
    default : Any, optional
        Default value used by ``__init__`` if this field is not provided.
        See :func:`dataclasses.field`.
    default_factory : Callable[[], Any] | Any, optional
        Used to generate a default value by ``__init__`` if this field is not
        provided. See :func:`dataclasses.field`.
    metadata : dict[str, Any] | None, optional
        Metadata to be attached to the field. The mapping is copied, so the
        caller's dictionary is never modified. Any existing ``"typecheck"``
        entry is overridden by the ``typecheck`` argument.

    Returns
    -------
    dataclasses.Field
        The field object produced by :func:`equinox.field`.
    """
    # Work on a copy: writing into the caller's dict would leak the
    # "typecheck" flag into every other field sharing that metadata object.
    merged_metadata: dict[str, Any] = {} if metadata is None else dict(metadata)
    merged_metadata["typecheck"] = typecheck
    return eqx_field(
        metadata=merged_metadata,
        **kwargs,
    )
@dataclasses.dataclass(frozen=True)
class TypedPolicy:
    """Used to configure the typechecking behavior of a module.

    As an example, if you want to disable typechecking for a specific method,
    you can use the following:

    .. code-block:: python

        class MyModule(TypedModule, typed_policy=TypedPolicy(skip_methods={"wtf"})):
            a: int
            b: int = field(typecheck=False)

            def wtf(self, x: int):
                return x

        z = MyModule(1, "not an int")
        z.wtf("also not an int")

    .. hint::

        You can also use the :func:`typing.no_type_check` decorator to disable
        typechecking for a specific method.

    Parameters
    ----------
    always_validated : bool, default True
        If True, every argument of methods will be type checked with a custom
        validator if present.
    typecheck_init_result : bool, default True
        If False, type checking will not be performed on the result of
        ``__init__``.
    skip_methods : Iterable[str], default empty
        Specifies which methods should skip type checking. Any iterable of
        names is accepted and stored as a ``frozenset``. A single ``str`` is
        treated as one method name rather than as an iterable of characters.
    """

    always_validated: bool = dataclasses.field(default=True)
    typecheck_init_result: bool = dataclasses.field(default=True)
    skip_methods: frozenset[str] = dataclasses.field(default_factory=frozenset)

    def __init__(
        self,
        always_validated: bool = True,
        typecheck_init_result: bool = True,
        skip_methods: Iterable[str] = frozenset(),
    ):
        # A bare string is itself an iterable of one-character strings;
        # frozenset("wtf") would yield {"w", "t", "f"} and skip nothing useful.
        if isinstance(skip_methods, str):
            skip_methods = (skip_methods,)
        # The dataclass is frozen, so attributes must be set through
        # object.__setattr__ to bypass the generated __setattr__ guard.
        object.__setattr__(self, "always_validated", always_validated)
        object.__setattr__(self, "typecheck_init_result", typecheck_init_result)
        object.__setattr__(self, "skip_methods", frozenset(skip_methods))
# Maps each class created by the metaclass below to its TypedPolicy.
# Written in ``RealTypedModuleMeta.__new__`` and read back in ``__call__``.
# Keys are held weakly.
policy_for_type: weakref.WeakKeyDictionary[type, TypedPolicy] = (
    weakref.WeakKeyDictionary()
)


def mark_as_typed[T: Callable](fn: T) -> T:
    """Set the ``__typinox_typed__`` flag on ``fn`` (if not already set) and return it."""
    if getattr(fn, "__typinox_typed__", False):
        return fn
    setattr(fn, "__typinox_typed__", True)
    return fn


def marked_as_typed(fn: Callable) -> bool:
    """Return the ``__typinox_typed__`` flag of ``fn``, or False when absent."""
    return getattr(fn, "__typinox_typed__", False)


def decorate_function(fn: Callable) -> Callable:
    """Wrap ``fn`` with ``jaxtyped``, using ``beartype.beartype`` as the typechecker."""
    return jaxtyped(fn, typechecker=beartype.beartype)


def fold_or(args: Sequence[Any]) -> Any:
    """Combine ``args`` left to right with the ``|`` operator.

    Returns ``Never`` for an empty sequence and the sole element for a
    sequence of length one."""
    if len(args) == 0:
        return Never
    result = args[0]
    for arg in args[1:]:
        result = result | arg
    return result


def sanitize_annotation(annotation: Any, cls: type) -> Any:
    """Recursively sanitize the annotation.

    Replaces ``Self`` with the class itself; recurses into ``Union`` and
    similar types. Annotations that match none of the handled forms are
    returned unchanged.

    Raises :class:`TypinoxNotImplementedError` for a generic alias whose
    origin cannot be determined."""
    if annotation is Self:
        return cls
    # Unions: sanitize every member, then rebuild the union with ``|``.
    if isinstance(annotation, UnionType | UnionGenericAlias):
        args = get_args(annotation)
        return fold_or([sanitize_annotation(arg, cls) for arg in args])
    # Unpack[X]: sanitize the single inner argument and re-wrap it.
    if isinstance(annotation, UnpackType):
        args = get_args(annotation)
        assert len(args) == 1
        inner = args[0]
        return Unpack[sanitize_annotation(inner, cls)]
    # Generic aliases such as ``tuple[int, str]``: rebuild with sanitized args.
    if isinstance(annotation, GenericAliasType):
        # Callable aliases are returned as-is, without recursing.
        if isinstance(annotation, CallableAliasType):
            return annotation
        origin = get_origin(annotation)
        args = get_args(annotation)
        if origin is None:
            raise TypinoxNotImplementedError(
                f"Unsupported GenericAlias: {annotation}"
            )
        return origin[tuple(sanitize_annotation(arg, cls) for arg in args)]
    # Vmapped annotations: only rewritten when their origin is exactly Self;
    # otherwise control falls through to the checks below.
    if isinstance(annotation, VmappedMeta):
        origin = get_vmapped_origin_or_none(annotation)
        if origin is Self:
            return cast(type[AbstractVmapped], annotation).replace_inner(cls)
    # Annotated[X, ...]: note that Annotated[Self, ...] collapses to ``cls``
    # and drops the metadata; any other origin is sanitized and re-annotated.
    if isinstance(annotation, AnnotatedAlias):
        origin = getattr(annotation, "__origin__")
        if origin is Self:
            return cls
        return Annotated[
            sanitize_annotation(origin, cls),
            *getattr(annotation, "__metadata__", []),
        ]
    return annotation


def sanitize_member_annotation(annotation: Any, cls: type) -> Any:
    """Sanitize the annotation of a class member.

    ``AbstractVar[X]`` becomes ``sanitized(X) | property``;
    ``AbstractClassVar`` raises :class:`TypinoxNotImplementedError`;
    everything else is passed to :func:`sanitize_annotation`."""
    if get_origin(annotation) is AbstractVar:
        annotation_args = get_args(annotation)
        assert len(annotation_args) == 1
        inner_annotation = sanitize_annotation(annotation_args[0], cls)
        # NOTE(review): presumably ``property`` is allowed because an abstract
        # variable may be provided by a property -- confirm.
        return inner_annotation | property
    if get_origin(annotation) is AbstractClassVar:
        raise TypinoxNotImplementedError(
            "AbstractClassVar is not yet supported by Typinox."
        )
    return sanitize_annotation(annotation, cls)


def method_transform_annotations(
    fn: FunctionType, cls: type, policy: TypedPolicy
) -> FunctionType:
    """Sanitize all annotations of a method.

    Also wraps every annotation with ValidatedT[] when
    ``policy.always_validated`` is true. ``fn.__annotations__`` is modified
    in place and ``fn`` itself is returned. String annotations are left
    untouched and produce a warning."""
    annotations = fn.__annotations__
    # Only values of existing keys are reassigned below, so the dict size does
    # not change while it is being iterated.
    for key, value in annotations.items():
        if isinstance(value, str):
            warnings.warn(
                f"Typinox: string annotations are not supported: `{value}` in {fn} of {cls}"
            )
            continue
        new_annotation = sanitize_annotation(value, cls)
        if policy.always_validated:
            new_annotation = ValidatedT[new_annotation]  # type: ignore
        if new_annotation is not value:
            annotations[key] = new_annotation
    return fn


def is_magic(name: str) -> bool:
    """Return True when ``name`` both starts and ends with a double underscore."""
    return name.startswith("__") and name.endswith("__")


# Descriptor kinds that ``decorate_method`` accepts alongside plain callables.
CallableDescriptor = staticmethod | classmethod | property | EqxWrapMethod

# Magic methods whose ``__module__`` is one of these are never decorated.
SKIP_MAGIC_MODULES = frozenset(
    [
        "builtins",
        "typing",
        "dataclasses",
        "typinox",
        "typinox.module",
        "typinox._module",
        "equinox",
        "equinox._module",
    ]
)

# Magic methods with these names are never decorated.
SKIP_MAGIC_NAMES = frozenset(
    [
        "__validate__",
        "__validate_str__",
        "__validate_self_str__",
    ]
)


def skip_magic(
    name: str, fn: Callable | CallableDescriptor, cls: type, policy: TypedPolicy
) -> bool:
    """Decide whether a magic method should be left undecorated.

    True when ``name`` is in ``SKIP_MAGIC_NAMES`` or ``fn.__module__`` is in
    ``SKIP_MAGIC_MODULES``. ``cls`` and ``policy`` are currently unused."""
    if name in SKIP_MAGIC_NAMES:
        return True
    if fn.__module__ in SKIP_MAGIC_MODULES:
        return True
    return False


def decorate_method[T: Callable | CallableDescriptor](
    name: str, fn: T, cls: type, policy: TypedPolicy
) -> T:
    """Decorate a method with the typechecker (jaxtyped and beartype).

    Recurses into staticmethods, classmethods and properties.

    ``fn`` is returned unchanged when ``name`` is in ``policy.skip_methods``,
    when it carries a truthy ``__no_type_check__`` attribute, when it is not
    callable, when it is not a plain Python function, when it is already
    marked as typed, or when it is a magic method rejected by
    :func:`skip_magic`."""
    if name in policy.skip_methods:
        return fn
    if getattr(fn, "__no_type_check__", False):
        return fn
    # Descriptors: decorate the underlying function(s), then rebuild a new
    # descriptor of the same kind around the result.
    if isinstance(fn, staticmethod):
        actual_method = fn.__func__
        return cast(
            T, staticmethod(decorate_method(name, actual_method, cls, policy))
        )
    if isinstance(fn, classmethod):
        actual_method = fn.__func__
        return cast(
            T, classmethod(decorate_method(name, actual_method, cls, policy))
        )
    if isinstance(fn, property):
        fget = (
            decorate_method(name, fn.fget, cls, policy)
            if fn.fget is not None
            else None
        )
        fset = (
            decorate_method(name, fn.fset, cls, policy)
            if fn.fset is not None
            else None
        )
        fdel = (
            decorate_method(name, fn.fdel, cls, policy)
            if fn.fdel is not None
            else None
        )
        return cast(T, property(fget, fset, fdel))
    # # not needed after equinox 0.13
    # if isinstance(fn, EqxWrapMethod):
    #     return cast(
    #         T, EqxWrapMethod(decorate_method(name, fn.method, cls, policy))
    #     )
    if not callable(fn):
        return fn
    if not inspect.isfunction(fn):
        # We can only wrap Python-native functions.
        debug_warn(
            f"Typinox: attempting to perform typechecking decoration on unknown object: {fn}",
            TypinoxUnknownFunctionWarning,
        )
        return cast(T, fn)
    # Avoid decorating the same function twice.
    if marked_as_typed(fn):
        return cast(T, fn)
    if is_magic(name):
        if skip_magic(name, fn, cls, policy):
            return cast(T, fn)
    # Main case: pure-python function.
    pyfunc = cast(FunctionType, fn)
    pyfunc = method_transform_annotations(pyfunc, cls, policy)
    decorated = decorate_function(pyfunc)
    decorated = mark_as_typed(decorated)
    return cast(T, decorated)


class RealTypedModuleMeta(EqxModuleMeta):
    """Metaclass for TypedModule.

    If you want to create a module with a metaclass other than
    :class:`abc.ABCMeta`, :class:`equinox._module._ModuleMeta` or
    :class:`typinox.module.TypedModuleMeta`, you need to create a new
    metaclass that inherits from this class and your metaclass of choice.
    """

    # Builds the class, decorates its methods with the typechecker, and
    # installs the ``__validate_str__`` / ``__validate_self_str__`` validators.
    # ``typed_policy`` may be a TypedPolicy, a dict of TypedPolicy keyword
    # arguments, or None (default policy).
    def __new__(
        mcs,
        name,
        bases,
        dict_,
        /,
        strict: bool | None = False,
        typed_policy: TypedPolicy | dict | None = None,
        **kwargs,
    ):
        # [Step 1] Create the Module as normal.
        cls = super().__new__(mcs, name, bases, dict_, strict=strict, **kwargs)
        # Assumption:
        # - Every non-magic normal method is wrapped by Equinox.
        # - A __init__ method is created, either by the user or by Equinox.

        # [Step 2] Wrap all methods with the typechecker.
        # [Step 2.0] Prepare the typechecking policy.
        if isinstance(typed_policy, dict):
            typed_policy = TypedPolicy(**typed_policy)
        if typed_policy is None:
            typed_policy = TypedPolicy()
        policy_for_type[cls] = typed_policy

        # [Step 2.1] Wrap the methods with the typechecker.
        # Only attributes defined directly on this class (``cls.__dict__``)
        # are visited; existing keys are replaced, none are added.
        for key, value in cls.__dict__.items():
            if key == "__init__":
                # We skip __init__ method generated by Equinox.
                if _has_dataclass_init[cls]:
                    continue
            decorated_value = decorate_method(key, value, cls, typed_policy)
            if decorated_value is not value:
                setattr(cls, key, decorated_value)

        # [Step 3] Add the validator methods.
        # Capture the validators defined directly on this class (if any)
        # before they are replaced at the end of this method.
        old_validate = cls.__dict__.get("__validate__", None)
        old_validate_str = cls.__dict__.get("__validate_str__", None)

        # [Step 3.1] Recursively validate the fields.
        # [Step 3.1.0] Prepare the annotations to check.
        sanitized_annotations = {
            key: sanitize_member_annotation(value, cls)
            for key, value in cls.__annotations__.items()
        }
        # Exclude the fields that are marked as not typechecking.
        # (The loop variable shadows the module-level ``field`` function
        # within this method only.)
        for field in dataclasses.fields(cls):
            if not field.metadata.get("typecheck", True):
                sanitized_annotations.pop(field.name, None)

        # [Step 3.1 cont'd] Actually validate the fields.
        # Returns "" when the instance is valid, otherwise a string
        # describing the first failure found.
        def __validate_self_str__(self):
            __tracebackhide__ = True
            # 1. Check each annotated member present in the instance dict.
            for member, hint in sanitized_annotations.items():
                if member not in self.__dict__:
                    continue
                value = self.__dict__[member]
                if not is_bearable(value, hint):
                    return f"its {member} does not match type hint {hint}, got {value}"
            # 2. Run the user's string-returning validator; a truthy result
            #    is reported as the failure reason.
            if old_validate_str is not None:
                result = old_validate_str(self)
                if result:
                    return result
            # 3. Run the user's ``__validate__``: it fails either by raising
            #    ValidationFailed or by returning exactly False.
            if old_validate is not None:
                try:
                    result = old_validate(self)
                except ValidationFailed as e:
                    return str(e)
                if result is False:
                    return "it failed its custom validation"
            return ""

        # Runs the validation above inside a ``jaxtyped("context")`` block.
        def __validate_str__(self):
            __tracebackhide__ = True
            with jaxtyped("context"):  # type: ignore
                return __validate_self_str__(self)

        # Add the methods to the class.
        # ``ABCMeta.__setattr__`` is called directly, which bypasses any
        # ``__setattr__`` defined on metaclasses between this one and ABCMeta.
        __validate_str__.__qualname__ = "__validate_str__"
        ABCMeta.__setattr__(cls, "__validate_self_str__", __validate_self_str__)
        ABCMeta.__setattr__(cls, "__validate_str__", __validate_str__)
        # The original ``__validate__`` (if any) stays reachable through the
        # ``old_validate`` closure variable above.
        ABCMeta.__setattr__(cls, "__validate__", None)
        return cls

    # Creating an instance with MyModule(...) will call this method.
    # Raises TypinoxTypeViolation when the policy enables result checking and
    # ``validate_str`` returns a non-empty failure description.
    def __call__(cls, *args, **kwargs):
        __tracebackhide__ = True
        # [Step 1] Create the instance as normal.
        instance = super().__call__(*args, **kwargs)
        # [Step 2] Typecheck the instance.
        policy = policy_for_type[cls]
        if policy.typecheck_init_result:
            check_result = validate_str(instance)
            if check_result:
                raise TypinoxTypeViolation(
                    f"The instance {instance} of {cls} has failed typechecking, as {check_result}"
                )
        return instance


# Expose the metaclass under its public name.
RealTypedModuleMeta.__name__ = "TypedModuleMeta"