jaxtyping v0.2.32
- The array type can now be either `Any` or a `TypeVar`. In both cases, anything is allowed at runtime. As usual, static type checkers look only at the array part of an annotation. An annotation of the form `Float[T, "foo bar"]` (where `T = TypeVar("T")`) is therefore treated as just `T` by static type checkers. This lets you express array-type polymorphism with static type checking. Here's an example:

```python
from typing import TypeVar

import numpy as np
import torch
from jaxtyping import Float

TensorLike = TypeVar("TensorLike", np.ndarray, torch.Tensor)


# "" denotes a scalar (0-dimensional) input; "2" denotes a 1D output of length 2.
def stack_scalars(
    x: Float[TensorLike, ""], y: Float[TensorLike, ""]
) -> Float[TensorLike, "2"]:
    # Dispatch on the concrete array type; both inputs must come from the same library.
    if isinstance(x, np.ndarray) and isinstance(y, np.ndarray):
        return np.stack([x, y])
    elif isinstance(x, torch.Tensor) and isinstance(y, torch.Tensor):
        return torch.stack([x, y])
    else:
        raise ValueError("Invalid array types!")
```
- Fixed a bug in which the very first argument to a function was erroneously reported as the one at fault for a typechecking error. This bug occurred when using default arguments.
Full Changelog: v0.2.31...v0.2.32