diff --git a/scipy-stubs/stats/_distn_infrastructure.pyi b/scipy-stubs/stats/_distn_infrastructure.pyi index 97eb73a4..07390680 100644 --- a/scipy-stubs/stats/_distn_infrastructure.pyi +++ b/scipy-stubs/stats/_distn_infrastructure.pyi @@ -11,6 +11,7 @@ from typing_extensions import Self, TypeVar, Unpack, override import numpy as np import optype as op import optype.numpy as onp +import optype.numpy.compat as npc from scipy._typing import RNG, AnyShape, Falsy, ToRNG, Truthy from scipy.integrate._typing import QuadOpts as _QuadOpts @@ -24,16 +25,15 @@ _RVT = TypeVar("_RVT", bound=rv_generic, default=rv_generic) _RVT_co = TypeVar("_RVT_co", bound=rv_generic, default=rv_generic, covariant=True) _CRVT_co = TypeVar("_CRVT_co", bound=rv_continuous, default=rv_continuous, covariant=True) _DRVT_co = TypeVar("_DRVT_co", bound=rv_discrete, default=rv_discrete, covariant=True) -_XKT_co = TypeVar("_XKT_co", bound=np.number[Any], covariant=True, default=np.number[Any]) +_XKT_co = TypeVar("_XKT_co", bound=_CoFloat, covariant=True, default=_CoFloat) _PKT_co = TypeVar("_PKT_co", bound=_Floating, covariant=True, default=_Floating) _Tuple2: TypeAlias = tuple[_T, _T] _Tuple3: TypeAlias = tuple[_T, _T, _T] _Tuple4: TypeAlias = tuple[_T, _T, _T, _T] -_Integer: TypeAlias = np.integer[Any] _Floating: TypeAlias = np.float64 | np.float32 | np.float16 # longdouble often results in trouble -_CoFloat: TypeAlias = _Floating | _Integer +_CoFloat: TypeAlias = _Floating | npc.integer _Bool: TypeAlias = bool | np.bool_ _Int: TypeAlias = int | np.int32 | np.int64 @@ -841,35 +841,68 @@ class rv_discrete(_rv_mixin, rv_generic): inc: Final[int] moment_tol: Final[float] + @overload def __new__( cls, a: onp.ToFloat = 0, b: onp.ToFloat = ..., name: str | None = None, badvalue: _Float | None = None, - moment_tol: _Float = 1e-08, - values: _Tuple2[_ToFloatOrND] | None = None, + moment_tol: _Float = 1e-8, + values: None = None, inc: int | np.int_ = 1, longname: str | None = None, shapes: str | None = None, seed: ToRNG = None, ) -> Self: ... - def __init__( # pyright: ignore[reportInconsistentConstructor] + # NOTE: The return types of the following overloads is ignored by mypy + @overload + def __new__( + cls, + a: onp.ToFloat, + b: onp.ToFloat, + name: str | None, + badvalue: _Float | None, + moment_tol: _Float, + values: _Tuple2[onp.ToFloatND], + inc: int | np.int_ = 1, + longname: str | None = None, + shapes: str | None = None, + seed: ToRNG = None, + ) -> rv_sample: ... + @overload + def __new__( + cls, + a: onp.ToFloat = 0, + b: onp.ToFloat = ..., + name: str | None = None, + badvalue: _Float | None = None, + moment_tol: _Float = 1e-8, + *, + values: _Tuple2[onp.ToFloatND], + inc: int | np.int_ = 1, + longname: str | None = None, + shapes: str | None = None, + seed: ToRNG = None, + ) -> rv_sample: ... + + # + def __init__( self, /, a: onp.ToFloat = 0, b: onp.ToFloat = ..., name: str | None = None, badvalue: _Float | None = None, - moment_tol: _Float = 1e-08, - values: None = None, + moment_tol: _Float = 1e-8, + # mypy workaround: `values` can only be None + values: _Tuple2[onp.ToFloatND] | None = None, inc: int | np.int_ = 1, longname: str | None = None, shapes: str | None = None, seed: ToRNG = None, ) -> None: ... - # # NOTE: Using `@override` on `__call__` or `freeze` causes stubtest to crash (mypy 1.11.1) @overload def __call__(self, /) -> rv_discrete_frozen[Self, _Float]: ... @@ -877,6 +910,7 @@ class rv_discrete(_rv_mixin, rv_generic): def __call__(self, /, *args: onp.ToFloat, loc: onp.ToFloat = 0, **kwds: onp.ToFloat) -> rv_discrete_frozen[Self, _Float]: ... @overload def __call__(self, /, *args: _ToFloatOrND, loc: _ToFloatOrND = 0, **kwds: _ToFloatOrND) -> rv_discrete_frozen[Self]: ... + # @overload def freeze(self, /) -> rv_discrete_frozen[Self, _Float]: ... @@ -1028,22 +1062,23 @@ class rv_discrete(_rv_mixin, rv_generic): **kwds: _ToFloatOrND, ) -> _IntOrND: ... -# undocumented +# returned by `rv_discrete.__new__` if `values` is specified class rv_sample(rv_discrete, Generic[_XKT_co, _PKT_co]): xk: onp.Array1D[_XKT_co] pk: onp.Array1D[_PKT_co] qvals: onp.Array1D[_PKT_co] - def __init__( # pyright: ignore[reportInconsistentConstructor] + def __init__( self, /, a: onp.ToFloat = 0, b: onp.ToFloat = ..., name: str | None = None, - badvalue: float | None = None, - moment_tol: float = 1e-08, - values: tuple[_ToFloatOrND, _ToFloatOrND] | None = None, - inc: int = 1, + badvalue: _Float | None = None, + moment_tol: _Float = 1e-8, + # never None in practice, but required by stubtest + values: _Tuple2[onp.ToFloatND] | None = None, + inc: int | np.int_ = 1, longname: str | None = None, shapes: str | None = None, seed: ToRNG = None, diff --git a/tests/stats/test_rv_sample.pyi b/tests/stats/test_rv_sample.pyi new file mode 100644 index 00000000..201dfd39 --- /dev/null +++ b/tests/stats/test_rv_sample.pyi @@ -0,0 +1,11 @@ +from typing_extensions import assert_type + +import numpy as np +import optype.numpy as onp +from scipy.stats._distn_infrastructure import rv_discrete, rv_sample + +xk: onp.Array1D[np.int_] +pk: tuple[float, ...] + +# mypy fails because it (still) doesn't support __new__ returning something that isn't `Self` +assert_type(rv_discrete(values=(xk, pk)), rv_sample) # type: ignore[assert-type]