Skip to content

Commit

Permalink
🐛 stats: fix rv_discrete sample constructor (#418)
Browse files Browse the repository at this point in the history
  • Loading branch information
jorenham authored Jan 14, 2025
2 parents 5a48e4e + 136e09b commit 4b16e52
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 15 deletions.
65 changes: 50 additions & 15 deletions scipy-stubs/stats/_distn_infrastructure.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ from typing_extensions import Self, TypeVar, Unpack, override
import numpy as np
import optype as op
import optype.numpy as onp
import optype.numpy.compat as npc
from scipy._typing import RNG, AnyShape, Falsy, ToRNG, Truthy
from scipy.integrate._typing import QuadOpts as _QuadOpts

Expand All @@ -24,16 +25,15 @@ _RVT = TypeVar("_RVT", bound=rv_generic, default=rv_generic)
_RVT_co = TypeVar("_RVT_co", bound=rv_generic, default=rv_generic, covariant=True)
_CRVT_co = TypeVar("_CRVT_co", bound=rv_continuous, default=rv_continuous, covariant=True)
_DRVT_co = TypeVar("_DRVT_co", bound=rv_discrete, default=rv_discrete, covariant=True)
_XKT_co = TypeVar("_XKT_co", bound=np.number[Any], covariant=True, default=np.number[Any])
_XKT_co = TypeVar("_XKT_co", bound=_CoFloat, covariant=True, default=_CoFloat)
_PKT_co = TypeVar("_PKT_co", bound=_Floating, covariant=True, default=_Floating)

_Tuple2: TypeAlias = tuple[_T, _T]
_Tuple3: TypeAlias = tuple[_T, _T, _T]
_Tuple4: TypeAlias = tuple[_T, _T, _T, _T]

_Integer: TypeAlias = np.integer[Any]
_Floating: TypeAlias = np.float64 | np.float32 | np.float16 # longdouble often results in trouble
_CoFloat: TypeAlias = _Floating | _Integer
_CoFloat: TypeAlias = _Floating | npc.integer

_Bool: TypeAlias = bool | np.bool_
_Int: TypeAlias = int | np.int32 | np.int64
Expand Down Expand Up @@ -841,42 +841,76 @@ class rv_discrete(_rv_mixin, rv_generic):
inc: Final[int]
moment_tol: Final[float]

@overload
def __new__(
cls,
a: onp.ToFloat = 0,
b: onp.ToFloat = ...,
name: str | None = None,
badvalue: _Float | None = None,
moment_tol: _Float = 1e-08,
values: _Tuple2[_ToFloatOrND] | None = None,
moment_tol: _Float = 1e-8,
values: None = None,
inc: int | np.int_ = 1,
longname: str | None = None,
shapes: str | None = None,
seed: ToRNG = None,
) -> Self: ...
def __init__( # pyright: ignore[reportInconsistentConstructor]
# NOTE: The return types of the following overloads is ignored by mypy
@overload
def __new__(
cls,
a: onp.ToFloat,
b: onp.ToFloat,
name: str | None,
badvalue: _Float | None,
moment_tol: _Float,
values: _Tuple2[onp.ToFloatND],
inc: int | np.int_ = 1,
longname: str | None = None,
shapes: str | None = None,
seed: ToRNG = None,
) -> rv_sample: ...
@overload
def __new__(
cls,
a: onp.ToFloat = 0,
b: onp.ToFloat = ...,
name: str | None = None,
badvalue: _Float | None = None,
moment_tol: _Float = 1e-8,
*,
values: _Tuple2[onp.ToFloatND],
inc: int | np.int_ = 1,
longname: str | None = None,
shapes: str | None = None,
seed: ToRNG = None,
) -> rv_sample: ...

#
def __init__(
self,
/,
a: onp.ToFloat = 0,
b: onp.ToFloat = ...,
name: str | None = None,
badvalue: _Float | None = None,
moment_tol: _Float = 1e-08,
values: None = None,
moment_tol: _Float = 1e-8,
# mypy workaround: `values` can only be None
values: _Tuple2[onp.ToFloatND] | None = None,
inc: int | np.int_ = 1,
longname: str | None = None,
shapes: str | None = None,
seed: ToRNG = None,
) -> None: ...

#
# NOTE: Using `@override` on `__call__` or `freeze` causes stubtest to crash (mypy 1.11.1)
@overload
def __call__(self, /) -> rv_discrete_frozen[Self, _Float]: ...
@overload
def __call__(self, /, *args: onp.ToFloat, loc: onp.ToFloat = 0, **kwds: onp.ToFloat) -> rv_discrete_frozen[Self, _Float]: ...
@overload
def __call__(self, /, *args: _ToFloatOrND, loc: _ToFloatOrND = 0, **kwds: _ToFloatOrND) -> rv_discrete_frozen[Self]: ...

#
@overload
def freeze(self, /) -> rv_discrete_frozen[Self, _Float]: ...
Expand Down Expand Up @@ -1028,22 +1062,23 @@ class rv_discrete(_rv_mixin, rv_generic):
**kwds: _ToFloatOrND,
) -> _IntOrND: ...

# undocumented
# returned by `rv_discrete.__new__` if `values` is specified
class rv_sample(rv_discrete, Generic[_XKT_co, _PKT_co]):
xk: onp.Array1D[_XKT_co]
pk: onp.Array1D[_PKT_co]
qvals: onp.Array1D[_PKT_co]

def __init__( # pyright: ignore[reportInconsistentConstructor]
def __init__(
self,
/,
a: onp.ToFloat = 0,
b: onp.ToFloat = ...,
name: str | None = None,
badvalue: float | None = None,
moment_tol: float = 1e-08,
values: tuple[_ToFloatOrND, _ToFloatOrND] | None = None,
inc: int = 1,
badvalue: _Float | None = None,
moment_tol: _Float = 1e-8,
# never None in practice, but required by stubtest
values: _Tuple2[onp.ToFloatND] | None = None,
inc: int | np.int_ = 1,
longname: str | None = None,
shapes: str | None = None,
seed: ToRNG = None,
Expand Down
11 changes: 11 additions & 0 deletions tests/stats/test_rv_sample.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from typing_extensions import assert_type

import numpy as np
import optype.numpy as onp
from scipy.stats._distn_infrastructure import rv_discrete, rv_sample

xk: onp.Array1D[np.int_]
pk: tuple[float, ...]

# mypy fails because it (still) doesn't support __new__ returning something that isn't `Self`
assert_type(rv_discrete(values=(xk, pk)), rv_sample) # type: ignore[assert-type]

0 comments on commit 4b16e52

Please sign in to comment.