Skip to content

Commit

Permalink
stats: casually invent a way to do overload attributes (it even w…
Browse files Browse the repository at this point in the history
…orks on mypy)
  • Loading branch information
jorenham committed Dec 21, 2024
1 parent 650fc6b commit ea5b6eb
Showing 1 changed file with 48 additions and 10 deletions.
58 changes: 48 additions & 10 deletions scipy-stubs/stats/_distribution_infrastructure.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
# pyright: reportUnannotatedClassAttribute=false

import abc
from collections.abc import Mapping, Sequence, Set as AbstractSet
from typing import Any, ClassVar, Final, Generic, Literal as L, TypeAlias, overload
from collections.abc import Callable, Mapping, Sequence, Set as AbstractSet
from typing import Any, ClassVar, Final, Generic, Literal as L, Protocol, TypeAlias, overload, type_check_only
from typing_extensions import LiteralString, Never, Self, TypeVar, override

import numpy as np
Expand Down Expand Up @@ -61,6 +61,24 @@ _DrawProportions: TypeAlias = tuple[onp.ToFloat, onp.ToFloat, onp.ToFloat, onp.T
_CDist: TypeAlias = ContinuousDistribution[np.floating[Any], _ShapeT0]
_CDist0: TypeAlias = ContinuousDistribution[_FloatingT, tuple[()]]

@type_check_only
class _ParameterField(Protocol[_FloatingT_co, _ShapeT0_co]):
# This actually works (even on mypy)!
@overload
def __get__(
self: _ParameterField[_FloatingT, tuple[()]],
instance: object,
owner: type | None = None,
/,
) -> _FloatingT: ...
@overload
def __get__(
self: _ParameterField[_FloatingT, _ShapeT1],
instance: object,
owner: type | None = None,
/,
) -> onp.ArrayND[_FloatingT, _ShapeT1]: ...

###

_null: Final[_Null] = ...
Expand Down Expand Up @@ -233,7 +251,7 @@ class ContinuousDistribution(_BaseDistribution[_FloatingT_co, _ShapeT0_co], Gene
def __sub__(self, lshift: onp.ToFloat, /) -> ShiftedScaledDistribution[Self, _FloatingT_co, _ShapeT0_co]: ...
def __mul__(self, scale: onp.ToFloat, /) -> ShiftedScaledDistribution[Self, _FloatingT_co, _ShapeT0_co]: ...
def __truediv__(self, iscale: onp.ToFloat, /) -> ShiftedScaledDistribution[Self, _FloatingT_co, _ShapeT0_co]: ...
def __pow__(self, exp: onp.ToInt, /) -> MonotonicTransformedDistribution[Self, _FloatingT_co, _ShapeT0_co]: ...
def __pow__(self, exp: onp.ToInt, /) -> MonotonicTransformedDistribution[Self, _ShapeT0_co]: ...
__radd__ = __add__
__rsub__ = __sub__
__rmul__ = __mul__
Expand Down Expand Up @@ -284,7 +302,7 @@ class ContinuousDistribution(_BaseDistribution[_FloatingT_co, _ShapeT0_co], Gene
@overload
def llf(self, sample: onp.ToFloat | onp.ToFloatND, /, *, axis: AnyShape | None = -1) -> _Float | onp.ArrayND[_Float]: ...

#
_ElementwiseFunction: TypeAlias = Callable[[onp.ArrayND[np.float64]], onp.ArrayND[_FloatingT]]

# 7 years of asking and >400 upvotes, but still no higher-kinded typing support: https://github.com/python/typing/issues/548
class TransformedDistribution(
Expand All @@ -302,18 +320,38 @@ class TransformedDistribution(
) -> None: ...

class MonotonicTransformedDistribution(
TransformedDistribution[_CDistT_co, _FloatingT_co, _ShapeT0_co],
Generic[_CDistT_co, _FloatingT_co, _ShapeT0_co],
TransformedDistribution[_CDistT_co, np.float64, _ShapeT0_co],
Generic[_CDistT_co, _ShapeT0_co],
):
# TODO(jorenham)
...
_g: Final[_ElementwiseFunction]
_h: Final[_ElementwiseFunction]
_dh: Final[_ElementwiseFunction]
_logdh: Final[_ElementwiseFunction]
_increasing: Final[bool]
_repr_pattern: Final[str]

def __init__(
self: MonotonicTransformedDistribution[_CDist[_ShapeT0], _ShapeT0],
X: _CDistT_co,
/,
*args: Never,
g: _ElementwiseFunction,
h: _ElementwiseFunction,
dh: _ElementwiseFunction,
logdh: _ElementwiseFunction | None = None,
increasing: bool = True,
repr_pattern: str | None = None,
tol: opt.Just[float] | _Null = ...,
validation_policy: _ValidationPolicy = None,
cache_policy: _CachePolicy = None,
) -> None: ...

class TruncatedDistribution(
TransformedDistribution[_CDistT_co, _FloatingT_co, _ShapeT0_co],
Generic[_CDistT_co, _FloatingT_co, _ShapeT0_co],
):
lb: _FloatingT_co | onp.ArrayND[_FloatingT_co, _ShapeT0_co]
ub: _FloatingT_co | onp.ArrayND[_FloatingT_co, _ShapeT0_co]
lb: _ParameterField[_FloatingT_co, _ShapeT0_co]
ub: _ParameterField[_FloatingT_co, _ShapeT0_co]

@overload
def __init__(
Expand Down

0 comments on commit ea5b6eb

Please sign in to comment.