
jaxtyping v0.2.32

Released by github-actions on 12 Jul at 09:44.
  • The array type can now be either Any or a TypeVar. In both cases, anything is allowed at runtime. As usual, static type checkers only look at the array part of an annotation, so an annotation of the form Float[T, "foo bar"] (where T = TypeVar("T")) is treated as just T. This makes it possible to express array-type polymorphism that static type checkers understand. Here's an example:

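    import numpy as np
    import torch
    from typing import TypeVar
    from jaxtyping import Float
    
    # Constrained TypeVar: x, y and the return value must all be the same array type.
    TensorLike = TypeVar("TensorLike", np.ndarray, torch.Tensor)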
    
    def stack_scalars(x: Float[TensorLike, ""], y: Float[TensorLike, ""]) -> Float[TensorLike, "2"]:
        if isinstance(x, np.ndarray) and isinstance(y, np.ndarray):
            return np.stack([x, y])
        elif isinstance(x, torch.Tensor) and isinstance(y, torch.Tensor):
            return torch.stack([x, y])
        else:
            raise ValueError("Invalid array types!")
  • Fixed a bug in which the very first argument to a function was wrongly reported as the cause of a typechecking error. This bug occurred when default arguments were used.

Full Changelog: v0.2.31...v0.2.32