Source code for typinox.shaped

import dataclasses
from typing import Any, Protocol, overload


@dataclasses.dataclass(frozen=True)
class _FakeArray:
    shape: tuple[int, ...]
    dtype: Any = dataclasses.field(default=None)


[docs] class NDArray(Protocol): """ Duck match for any object that behaves like a numpy array (has ``.shape`` and ``.dtype`` attributes). """ @property def shape(self) -> tuple[int, ...]: ... @property def dtype(self) -> Any: ...
ShapeLike = int | tuple[int, ...] | NDArray def shape_sanitize(shape: ShapeLike) -> tuple[int, ...]: if isinstance(shape, int): return (shape,) if isinstance(shape, tuple): return shape return shape.shape @overload def ensure_shape(shape: ShapeLike, dim_spec: str, /): ... @overload def ensure_shape(name: str, shape: ShapeLike, dim_spec: str, /): ...
[docs] def ensure_shape(*args): """ Ensure that the shape of an array matches :mod:`jaxtyping` named dimensions. Parameters ---------- name: str, optional The name of the array. shape : ShapeLike The shape of the array. dim_spec : str The :mod:`jaxtyping` named dimensions to be matched. Returns ------- None If the shape matches the named dimensions, return None. Raises ------ ValidationFailed when the shape does not match the named dimensions """ from jaxtyping import Shaped from .validator import ValidationFailed if len(args) == 2: shape, dim_spec = args name = "" elif len(args) == 3: name, shape, dim_spec = args else: raise ValueError(f"invalid number of arguments: {len(args)}") obj = _FakeArray(shape_sanitize(shape)) if not isinstance(obj, Shaped[_FakeArray, dim_spec]): # type: ignore if name: raise ValidationFailed( f'{name} has shape {shape} which does not match the named dimensions "{dim_spec}"' ) else: raise ValidationFailed( f"shape {shape} does not match the named dimensions {dim_spec}" )
@overload def ensure_shape_equal(shape1: ShapeLike, shape2: ShapeLike, /): ... @overload def ensure_shape_equal(name: str, shape1: ShapeLike, shape2: ShapeLike, /): ... @overload def ensure_shape_equal( name1: str, shape1: ShapeLike, name2: str, shape2: ShapeLike, / ): ...
[docs] def ensure_shape_equal(*args): """ Ensure that the shapes of two arrays are equal. Parameters ---------- name : str, optional The name of the arrays. Cannot be provided if name1 or name2 is provided. name1 : str, optional The name of the first array. Must be provided if name2 is provided. shape1 : int, or tuple[int, ...], or an object with ``.shape`` The shape of the first array. name2 : str, optional The name of the second array. Must be provided if name1 is provided. shape2 : int, or tuple[int, ...], or an object with ``.shape`` The shape of the second array. Returns ------- None If the shapes are equal, return None. Raises ------ ValidationFailed when the shapes are not equal. """ from .validator import ValidationFailed if len(args) == 2: shape1, shape2 = args name1 = name2 = "" elif len(args) == 3: name, shape1, shape2 = args name1 = name2 = name elif len(args) == 4: name1, shape1, name2, shape2 = args else: raise ValueError(f"invalid number of arguments: {len(args)}") if shape_sanitize(shape1) != shape_sanitize(shape2): raise ValidationFailed( f"{name1} {shape1} does not match {name2} {shape2}" )