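"""Source code for typinox._vmapped.

Runtime support for the ``VmappedT`` and ``VmappedI`` type hints: PyTrees whose
array leaves carry extra (vmapped) axes before and/or after the shape expected
by an inner type hint.
"""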

import functools as ft
from collections.abc import Callable
from typing import Any, NoReturn

from beartype.door import is_bearable
from jax import (
    tree as jt,
    tree_util as jtu,
)
from jaxtyping import Array
from jaxtyping._array_types import (
    _anonymous_dim,
    _check_dims,
    _DimType,
    _FixedDim,
    _NamedDim,
    _SymbolicDim,
)
from jaxtyping._storage import get_shape_memo, set_shape_memo

from .error import TypinoxAnnotationError

# Unique sentinel standing for the position of the original (un-vmapped) shape
# inside a parsed dimension list; `parse_dims` compares against it with `is`.
_vmapped_dim: object = object()

# Character that marks, inside a dimension string, where the original shape
# sits relative to the extra vmapped axes (e.g. "batch $" or "$ batch").
VMAPPED_LETTER: str = "$"


def err_array_name(key_path: jtu.KeyPath) -> str:
    """Generates a friendly name from a key path.

    Parameters
    ----------
    key_path : jax.tree_util.KeyPath
        A PyTree key path.

    Returns
    -------
    str
        A human-readable name for the array.
    """
    if len(key_path) == 0:
        return "Array"
    else:
        return f"Array at key path {key_path}"


def single_check_vmapped(
    arr: Array | Any,
    key_path: jtu.KeyPath,
    dim_left: tuple,
    dim_right: tuple,
    single_memo: dict[str, int],
    arg_memo: dict[str, Any],
) -> Array | str | Any:
    """Check a single array for vmapped axes.

    Parameters
    ----------
    arr : Array | Any
        The array to check.
    key_path : jax.tree_util.KeyPath
        The key path of the array, used for error messages.
    dim_left : tuple[jaxtyping._array_types._AbstractDim]
        The dimension specifiers for the first several axes.
    dim_right : tuple[jaxtyping._array_types._AbstractDim]
        The dimension specifiers for the last several axes.
    single_memo : dict[str, int]
        For jaxtyping internal use.
    arg_memo : dict[str, Any]
        For jaxtyping internal use.

    Returns
    -------
    Array | str | Any
        If the array is valid, returns a slice of the array with the vmapped axes removed.
        If the array is invalid, returns an error message.
        If ``arr`` is not an array at all, return it as is.
    """

    if not hasattr(arr, "shape") or not hasattr(arr, "dtype"):
        # Not an array, ignored
        return arr
    shape = arr.shape
    if len(shape) < len(dim_left) + len(dim_right):
        return (
            err_array_name(key_path)
            + " has fewer dimensions than the vmapped axes"
        )
    left_idx = len(dim_left)
    right_idx = len(shape) - len(dim_right)
    if left_idx != 0:
        check = _check_dims(
            dim_left,  # type: ignore
            shape[:left_idx],
            single_memo,
            arg_memo,
        )
        if check != "":
            return err_array_name(key_path) + ": " + check
    if right_idx != len(shape):
        check = _check_dims(
            dim_right,  # type: ignore
            shape[right_idx:],
            single_memo,
            arg_memo,
        )
        if check != "":
            return err_array_name(key_path) + ": " + check
    # Generates a list of slice objects that:
    #   for the first `left_idx` axes, take the 0th element
    #   for the last `right_idx` axes, take the 0th element
    #   for the rest, take all elements
    # This removed the vmapped axes from the array.
    # caveat: doesn't work for axes with size 0
    item_idx = [
        0 if i < left_idx or i >= right_idx else slice(None)
        for i in range(len(shape))
    ]
    return arr[tuple(item_idx)]


def instancecheck_vmapped(
    inner: type,
    dims: tuple[tuple, tuple],
    checker: Callable[[Any, type], bool],
    obj,
    single_memo: dict[str, int],
    arg_memo: dict[str, Any],
) -> str:
    """Check a PyTree structure for vmapped axes.

    Parameters
    ----------
    inner : type
        PyTree structure to check against.
    dims : tuple of two tuples of jaxtyping._array_types._AbstractDim
        The dimension specifiers for the first several axes and the last several axes.
    checker : T -> * -> bool
        An `isinstance`-like function to check the PyTree structure.
    obj : T
        The object to check.
    single_memo : dict[str, int]
        For jaxtyping internal use.
    arg_memo : dict[str, Any]
        For jaxtyping internal use.

    Returns
    -------
    str
        `""` if the object passes the check, or an error message otherwise.
    """

    dim_left, dim_right = dims
    nodedefs, treedef = jt.flatten_with_path(obj)
    leaves: list[Any] = []
    for key_path, arr in nodedefs:
        # the case where `arr` is a string needs to be handled separately...
        if isinstance(arr, str):
            leaves.append(arr)
            continue
        check = single_check_vmapped(
            arr,
            key_path,
            dim_left,
            dim_right,
            single_memo,
            arg_memo,
        )
        # ... as `check` is an str if the single array check fails
        if isinstance(check, str):
            if check != "":
                return check
            check = arr
        leaves.append(check)

    # when all single arrays are valid,
    # map the leaves (without the vmapped axes) back to the PyTree structure
    new_obj = jt.unflatten(treedef, leaves)
    if hasattr(inner, "__instancecheck_str__"):
        return inner.__instancecheck_str__(new_obj)
    if checker(new_obj, inner):
        return ""
    return f"{new_obj} is not an instance of {inner}"


class VmappedMeta(type):
    """Metaclass of concrete ``Vmapped`` classes.

    ``isinstance(obj, cls)`` delegates to :meth:`__instancecheck_str__`, which
    runs :func:`instancecheck_vmapped` against the class's ``inner``, ``dims``
    and ``checker`` attributes. The jaxtyping shape memos are restored when
    the check fails or raises.
    """

    def __instancecheck__(cls, obj):
        return cls.__instancecheck_str__(obj) == ""

    def __instancecheck_str__(cls, obj) -> str:
        """Return ``""`` if ``obj`` matches ``cls``, otherwise an error message.

        Raises
        ------
        TypinoxAnnotationError
            If ``cls`` lacks any of the ``inner``, ``dims`` or ``checker``
            attributes (e.g. the abstract base class itself).
        """
        if not (
            hasattr(cls, "inner")
            and hasattr(cls, "dims")
            and hasattr(cls, "checker")
        ):
            raise TypinoxAnnotationError(
                "Invalid `Vmapped` class; must have `inner`, `dims` and `checker` attributes."
            )
        single_memo, variadic_memo, pytree_memo, arg_memo = get_shape_memo()
        # Snapshot every shape memo so it can be put back if this check
        # fails or raises.
        backups = (
            single_memo.copy(),
            variadic_memo.copy(),
            pytree_memo.copy(),
            arg_memo.copy(),
        )

        try:
            check = instancecheck_vmapped(
                cls.inner,  # type: ignore
                cls.dims,  # type: ignore
                cls.checker,  # type: ignore
                obj,
                single_memo,
                arg_memo,
            )
        except Exception:
            set_shape_memo(*backups)
            # Bare `raise` re-raises with the original traceback untouched.
            raise
        if check != "":
            set_shape_memo(*backups)
        return check

    def __repr__(cls):
        return cls.__module__ + "." + cls.__qualname__

    # `str(cls)` renders the same as `repr(cls)`.
    __str__ = __repr__


class AbstractVmapped(metaclass=VmappedMeta):
    """Common base class of every concrete ``Vmapped`` class.

    The attributes are only declared here; the concrete subclasses built by
    ``create_vmapped_class`` supply their values.
    """

    # Type hint the object must satisfy once its vmapped axes are removed.
    inner: type
    # Dimension string the class was created from.
    dim_str: str
    # Parsed (left, right) dimension specifiers around the original shape.
    dims: tuple[tuple, tuple]
    # `isinstance`-like predicate applied to the un-vmapped object.
    checker: Callable
    # Prefix used when building the class name.
    base_name: str

    @classmethod
    def replace_inner(cls, inner):
        """Return the ``Vmapped`` class with the same name prefix, dimension
        string, dims and checker as ``cls`` but wrapping ``inner``."""
        # Keep the arguments positional: the factory is cached on them.
        factory_args = (
            cls.base_name,
            inner,
            cls.dim_str,
            cls.dims,
            cls.checker,
        )
        return create_vmapped_class(*factory_args)


@ft.cache
def create_vmapped_class(base_name, inner, dim_str, dims, checker):
    """Build the concrete ``Vmapped`` class for the given parameters.

    Results are cached, so identical arguments always yield the very same
    class object. The class is named ``"<base_name>[<inner>, <dim_str>]"``,
    derives from ``AbstractVmapped`` and reports ``"typinox"`` as its module.
    """
    namespace = {
        "inner": inner,
        "dim_str": dim_str,
        "dims": dims,
        "checker": checker,
        "base_name": base_name,
    }
    vmapped_cls = VmappedMeta(
        f"{base_name}[{inner}, {dim_str}]",
        (AbstractVmapped,),
        namespace,
    )
    vmapped_cls.__module__ = "typinox"
    return vmapped_cls


def _split_dim_prefixes(elem: str) -> tuple[str, bool, bool]:
    """Strip the ``_``, vmapped-marker and ``name=`` prefixes from one token.

    Returns ``(rest, anonymous, vmapped)``. Raises ``TypinoxAnnotationError``
    for jaxtyping modifiers that ``Vmapped`` does not support.
    """
    anonymous = False
    vmapped = False
    # Each iteration either consumes at least one character, breaks or
    # raises, so the loop terminates. An empty remainder (e.g. from just
    # `_` or `$`) is valid.
    while elem:
        first_char = elem[0]
        if first_char == "#":
            raise TypinoxAnnotationError(
                "jaxtyping broadcastable annotation is unsupported in Vmapped; "
                "`#foo` is not allowed"
            )
        elif first_char == "*":
            raise TypinoxAnnotationError(
                "jaxtyping multiple axis annotation is unsupported in Vmapped; "
                "`*foo` is not allowed"
            )
        elif first_char == "_":
            if anonymous:
                raise TypinoxAnnotationError(
                    "Do not use _ twice to denote anonymity, e.g. `__foo` "
                    "is not allowed"
                )
            anonymous = True
            elem = elem[1:]
        elif first_char == "?":
            raise TypinoxAnnotationError(
                "jaxtyping treepath-dependent annotation is unsupported in Vmapped; "
                "`?foo` is not allowed"
            )
        elif first_char == VMAPPED_LETTER:
            if vmapped:
                raise TypinoxAnnotationError(
                    "Do not use "
                    + VMAPPED_LETTER
                    + " twice to denote vmapped axes, e.g. `"
                    + VMAPPED_LETTER
                    + VMAPPED_LETTER
                    + "` is not allowed"
                )
            vmapped = True
            elem = elem[1:]
        elif elem.count("=") == 1:
            # `doc=dim` form: the part before `=` is only documentation.
            _, elem = elem.split("=")
        else:
            break
    return elem, anonymous, vmapped


def _build_dim(elem: str, anonymous: bool, vmapped: bool):
    """Turn a prefix-stripped token into a jaxtyping dim or ``_vmapped_dim``."""
    elem_size = 0
    if len(elem) == 0 or elem.isidentifier():
        dim_type = _DimType.named
    else:
        try:
            elem_size = int(elem)
        except ValueError:
            dim_type = _DimType.symbolic
        else:
            dim_type = _DimType.fixed

    out: _FixedDim | _NamedDim | _SymbolicDim | object
    if dim_type is _DimType.fixed:
        if anonymous:
            raise TypinoxAnnotationError(
                "Cannot have a fixed axis be anonymous, e.g. `_4` is not allowed."
            )
        if vmapped:
            raise TypinoxAnnotationError(
                "Cannot have a fixed axis be vmapped, e.g. `"
                + VMAPPED_LETTER
                + "4` is not allowed."
            )
        out = _FixedDim(elem_size, False)
    elif dim_type is _DimType.named:
        # The vmapped marker wins over anonymity; any name after it is dropped.
        if vmapped:
            out = _vmapped_dim
        elif anonymous:
            out = _anonymous_dim
        else:
            out = _NamedDim(elem, False, False)
    else:
        assert dim_type is _DimType.symbolic
        if anonymous:
            raise TypinoxAnnotationError(
                "Cannot have a symbolic axis be anonymous, "
                "e.g. `_foo+bar` is not allowed."
            )
        if vmapped:
            raise TypinoxAnnotationError(
                "Cannot have a symbolic axis be vmapped, "
                "e.g. `" + VMAPPED_LETTER + "foo+bar` is not allowed."
            )
        out = _SymbolicDim(elem, False)
    return out


def parse_dims(dim_str: str):
    """Parse a ``Vmapped`` dimension string into left and right specifiers.

    Parameters
    ----------
    dim_str : str
        Space-separated dimension specifiers, e.g. ``"batch $ 3"``. At most
        one token may carry the vmapped marker ``$`` (``VMAPPED_LETTER``);
        when none does, the marker is implicitly appended at the end.

    Returns
    -------
    tuple of two tuples, or None
        ``(left, right)``: the specifiers before and after the vmapped
        marker. ``None`` when ``dim_str`` declares no axis besides the
        marker itself.

    Raises
    ------
    TypinoxAnnotationError
        On comma-separated axes, unsupported jaxtyping modifiers (``...``,
        ``#``, ``*``, ``?``), repeated ``_`` or ``$`` prefixes, fixed or
        symbolic axes marked anonymous or vmapped, or more than one vmapped
        axis.
    """
    dims = []
    for elem in dim_str.split():
        if "," in elem and "(" not in elem:
            # Common mistake.
            # Disable in the case that there's brackets to allow for function calls,
            # e.g. `min(foo,bar)`, in symbolic axes.
            raise TypinoxAnnotationError(
                "Axes should be separated with spaces, not commas"
            )
        if elem == "...":
            raise TypinoxAnnotationError(
                "jaxtyping multiple axis not supported in Vmapped; "
                "`...` is not allowed"
            )
        dims.append(_build_dim(*_split_dim_prefixes(elem)))

    n_vmapped_dims = sum(1 for dim in dims if dim is _vmapped_dim)
    if n_vmapped_dims > 1:
        raise TypinoxAnnotationError(
            "Only one axis can be marked as vmapped, e.g. `n "
            + VMAPPED_LETTER
            + " m p` is allowed, but `n "
            + VMAPPED_LETTER
            + " m "
            + VMAPPED_LETTER
            + " p` is not."
        )
    if n_vmapped_dims == 0:
        # Default: the original shape comes after all listed axes.
        dims.append(_vmapped_dim)
    if len(dims) == 1:
        # Only the marker itself: no extra axes at all.
        return None
    idx = dims.index(_vmapped_dim)
    left, right = dims[:idx], dims[idx + 1 :]
    return tuple(left), tuple(right)


@ft.cache
def make_vmapped(name, inner, dim_str, checker):
    """Build (cached) the ``Vmapped`` class for ``inner`` extended by ``dim_str``.

    Returns ``inner`` itself when ``dim_str`` declares no extra axes. When
    ``inner`` is already a ``Vmapped`` class, the two layers are flattened
    into one: the new leading axes go in front of the old ones and the new
    trailing axes go after the old ones.
    """
    parsed = parse_dims(dim_str)
    if parsed is None:
        return inner
    new_left, new_right = parsed

    if isinstance(inner, type) and issubclass(inner, AbstractVmapped):
        wrapped = inner.inner
        old_left, old_right = inner.dims
        new_left = new_left + old_left
        new_right = old_right + new_right
        inner = wrapped

    # Positional call: the factory is cached on its exact arguments.
    return create_vmapped_class(
        name, inner, dim_str, (new_left, new_right), checker
    )


def get_vmapped_origin_or_none(cls):
    """Return the ``inner`` hint of a ``Vmapped`` class, or ``None``.

    ``None`` is returned for anything that is not a class, or a class that
    does not derive from ``AbstractVmapped``.
    """
    # `isinstance` guards `issubclass`, which would raise on non-classes.
    is_vmapped = isinstance(cls, type) and issubclass(cls, AbstractVmapped)
    return cls.inner if is_vmapped else None


class VmappedHelperMeta(type):
    """Metaclass of the user-facing ``VmappedT`` / ``VmappedI`` helpers.

    Subscripting a helper class (``Helper[inner, "dims"]``) builds a concrete
    ``Vmapped`` class via ``make_vmapped``, using the helper's ``checker``
    attribute (``isinstance`` when absent). Using the bare helper in an
    ``isinstance`` check is an error.
    """

    def __instancecheck__(cls, obj) -> NoReturn:
        raise TypinoxAnnotationError(
            "Do not use `Vmapped` as a type hint without specifying"
            " the PyTree structure and the vmapped dimension."
        )

    def __getitem__(cls, params):
        """Return the ``Vmapped`` class for ``(inner, dim_str)``.

        Raises
        ------
        TypinoxAnnotationError
            If ``params`` is not a 2-tuple, or its second element is not a
            string.
        """
        if not isinstance(params, tuple) or len(params) != 2:
            raise TypinoxAnnotationError(
                "Vmapped type hint must be a tuple of a PyTree structure and a string."
            )
        inner, dim_str = params
        # Previously a non-string (e.g. swapped arguments) crashed with an
        # opaque AttributeError on `.strip()`.
        if not isinstance(dim_str, str):
            raise TypinoxAnnotationError(
                "Vmapped type hint must be a tuple of a PyTree structure and a string; "
                f"got {type(dim_str).__name__} as the dimension specification."
            )
        checker = getattr(cls, "checker", isinstance)
        return make_vmapped(cls.__qualname__, inner, dim_str.strip(), checker)


class VmappedT(metaclass=VmappedHelperMeta):
    """Vmapped type hint.

    When ``VmappedT[inner, shape]`` is used as a type hint, the object is
    expected to be a PyTree structure specified by ``inner``, with each array
    of it extended by ``shape``.

    For example, ``VmappedT[Linear, "device batch"]`` is a type hint for a
    ``Linear`` PyTree with each array shape prepended by two dimensions,
    named ``device`` and ``batch``.

    **Shape**

    The ``shape`` is a space-separated string of dimension specifiers, such
    as ``"a 3"``. Each dimension can be one of the following:

    - A named dimension, such as ``a``.
    - A fixed dimension, such as ``3``.
    - A symbolic expression in terms of other variable-size axes, such as
      ``a-1``.
    - ``_``: An anonymous dimension, which is a named dimension that is not
      matched to any other dimension.

    They function similarly to the dimension specifiers in :mod:`jaxtyping`.
    It can also be:

    - ``$``: Represents the original array shape before vmapping. This is
      optional, and if omitted, it is assumed that ``$`` appears at the end.

    For example, if ``T`` is ``Float[Array, "3 4"]``, then
    ``VmappedT[T, "batch"]`` and ``VmappedT[T, "batch $"]`` are equivalent to
    ``Float[Array, "batch 3 4"]``; ``VmappedT[T, "$ batch"]`` is equivalent
    to ``Float[Array, "3 4 batch"]``.

    Other jaxtyping array annotation modifiers, such as ``*``, ``#``, ``?``,
    and ``...``, are not supported.

    .. warning::

        None of the axes specified in ``shape`` can be zero (e.g.,
        ``VmappedT[T, "0 3"]`` is invalid).

    **Inner type**

    The ``inner`` is the type hint for a PyTree structure of arrays. It can be
    any valid type hint, including another ``VmappedT`` type hint. It is most
    useful when ``isinstance(x, inner)`` also checks the shape of arrays in
    ``x``. Examples of such ``inner`` hints include:

    - **(jax or numpy) arrays**: ``isinstance(x, VmappedT[Array, "b c"])``
      checks that ``x`` is a jax array with at least two initial dimensions.
    - **tuple of arrays**:
      ``isinstance(x, VmappedT[tuple[Array, Array], "b c"])`` checks that both
      elements of ``x`` are jax arrays with at least two initial dimensions,
      and that their first two dimensions match.
    - **jaxtyping shaped arrays**: ``VmappedT[Float[Array, "b c"], "a"]`` is
      equivalent to ``Float[Array, "a b c"]``, while
      ``VmappedT[Float[Array, "b c"], "$ a"]`` is equivalent to
      ``Float[Array, "b c a"]``.
    - **Typinox Modules**: ``isinstance(x, VmappedT[M, "b c"])`` checks that
      ``x`` is an instance of a
      :class:`TypedModule <typinox.module.TypedModule>` class ``M``, and each
      array member of ``x`` has at least two initial dimensions. If these
      dimensions are removed, the resulting object should match ``M``.
    """

    def __init__(self) -> NoReturn:
        raise TypinoxAnnotationError(
            "Do not instantiate `VmappedT` directly; use it as a type hint."
        )

    # Predicate applied to the un-vmapped object (read by the metaclass).
    checker = is_bearable
class VmappedI(metaclass=VmappedHelperMeta):
    """Vmapped type hint without using :mod:`beartype`.

    Used only when you do not want to use
    :func:`is_bearable() <beartype.door.is_bearable>` to check the type of the
    object. Useful when the inner annotation is a :mod:`jaxtyping` array type
    hint.

    The interface is the same as :class:`VmappedT <typinox.vmapped.VmappedT>`;
    the difference is that the builtin :func:`isinstance` is used as the
    checker.
    """

    def __init__(self) -> NoReturn:
        raise TypinoxAnnotationError(
            "Do not instantiate `VmappedI` directly; use it as a type hint."
        )

    # Predicate applied to the un-vmapped object (read by the metaclass).
    checker = isinstance