diff --git a/scipy-stubs/special/_logsumexp.pyi b/scipy-stubs/special/_logsumexp.pyi
index 71f1748a..c0f45318 100644
--- a/scipy-stubs/special/_logsumexp.pyi
+++ b/scipy-stubs/special/_logsumexp.pyi
@@ -1,50 +1,98 @@
-from typing import Literal, overload
+from typing import overload
 
 import numpy as np
 import optype.numpy as onp
+from scipy._typing import AnyShape, Falsy, Truthy
 
 __all__ = ["log_softmax", "logsumexp", "softmax"]
 
-# TODO: Support `return_sign=True`
 @overload
 def logsumexp(
     a: onp.ToFloat,
-    axis: int | tuple[int, ...] | None = None,
+    axis: AnyShape | None = None,
     b: onp.ToFloat | None = None,
     keepdims: bool = False,
-    return_sign: Literal[False, 0] = False,
+    return_sign: Falsy = False,
 ) -> np.float64: ...
 @overload
 def logsumexp(
     a: onp.ToComplex,
-    axis: int | tuple[int, ...] | None = None,
+    axis: AnyShape | None = None,
     b: onp.ToFloat | None = None,
     keepdims: bool = False,
-    return_sign: Literal[False, 0] = False,
+    return_sign: Falsy = False,
 ) -> np.float64 | np.complex128: ...
 @overload
 def logsumexp(
     a: onp.ToFloatND,
-    axis: int | tuple[int, ...],
+    axis: AnyShape,
     b: onp.ToFloat | onp.ToFloatND | None = None,
     keepdims: bool = False,
-    return_sign: Literal[False, 0] = False,
+    return_sign: Falsy = False,
 ) -> np.float64 | onp.ArrayND[np.float64]: ...
 @overload
 def logsumexp(
     a: onp.ToComplexND,
-    axis: int | tuple[int, ...],
+    axis: AnyShape,
     b: onp.ToFloat | onp.ToFloatND | None = None,
     keepdims: bool = False,
-    return_sign: Literal[False, 0] = False,
+    return_sign: Falsy = False,
 ) -> np.float64 | np.complex128 | onp.ArrayND[np.float64 | np.complex128]: ...
+@overload
+def logsumexp(
+    a: onp.ToFloat,
+    axis: AnyShape | None = None,
+    b: onp.ToFloat | None = None,
+    keepdims: bool = False,
+    *,
+    return_sign: Truthy,
+) -> tuple[np.float64, bool | np.bool_]: ...
+@overload
+def logsumexp(
+    a: onp.ToComplex,
+    axis: AnyShape | None = None,
+    b: onp.ToFloat | None = None,
+    keepdims: bool = False,
+    *,
+    return_sign: Truthy,
+) -> tuple[np.float64 | np.complex128, bool | np.bool_]: ...
+@overload
+def logsumexp(
+    a: onp.ToFloatND,
+    axis: AnyShape,
+    b: onp.ToFloat | onp.ToFloatND | None = None,
+    keepdims: bool = False,
+    *,
+    return_sign: Truthy,
+) -> tuple[np.float64, bool | np.bool_] | tuple[onp.ArrayND[np.float64], onp.ArrayND[np.bool_]]: ...
+@overload
+def logsumexp(
+    a: onp.ToComplexND,
+    axis: AnyShape,
+    b: onp.ToFloat | onp.ToFloatND | None = None,
+    keepdims: bool = False,
+    *,
+    return_sign: Truthy,
+) -> (
+    tuple[np.float64 | np.complex128, bool | np.bool_] | tuple[onp.ArrayND[np.float64 | np.complex128], onp.ArrayND[np.bool_]]
+): ...
 
-# TODO: Overload real/complex and scalar/array
-def softmax(
-    x: onp.ToComplex | onp.ToComplexND,
-    axis: int | tuple[int, ...] | None = None,
-) -> np.float64 | np.complex128 | onp.ArrayND[np.float64 | np.complex128]: ...
-def log_softmax(
-    x: onp.ToComplex | onp.ToComplexND,
-    axis: int | tuple[int, ...] | None = None,
-) -> np.float64 | np.complex128 | onp.ArrayND[np.float64 | np.complex128]: ...
+#
+@overload
+def softmax(x: onp.ToFloat, axis: AnyShape | None = None) -> np.float64: ...
+@overload
+def softmax(x: onp.ToFloatND, axis: AnyShape | None = None) -> onp.ArrayND[np.float64]: ...
+@overload
+def softmax(x: onp.ToComplex, axis: AnyShape | None = None) -> np.float64 | np.complex128: ...
+@overload
+def softmax(x: onp.ToComplexND, axis: AnyShape | None = None) -> onp.ArrayND[np.float64 | np.complex128]: ...
+
+#
+@overload
+def log_softmax(x: onp.ToFloat, axis: AnyShape | None = None) -> np.float64: ...
+@overload
+def log_softmax(x: onp.ToFloatND, axis: AnyShape | None = None) -> onp.ArrayND[np.float64]: ...
+@overload
+def log_softmax(x: onp.ToComplex, axis: AnyShape | None = None) -> np.float64 | np.complex128: ...
+@overload
+def log_softmax(x: onp.ToComplexND, axis: AnyShape | None = None) -> onp.ArrayND[np.float64 | np.complex128]: ...
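For reference, here is a minimal usage sketch of the call patterns these overloads are meant to distinguish, assuming the stubs track the runtime scipy.special API; the variable names are illustrative only. With return_sign omitted or falsy, logsumexp returns a bare value; with return_sign=True it returns a (value, sign) pair; softmax and log_softmax resolve to a scalar for scalar input and to an ndarray for array input.

import numpy as np
from scipy.special import log_softmax, logsumexp, softmax

x = np.array([[1.0, 2.0], [3.0, 4.0]])

# Falsy return_sign (the default): a plain result; scalar in -> np.float64 out.
val = logsumexp(2.0)

# Truthy return_sign (keyword-only in the new overloads): a (value, sign) pair.
# With a negative weight b, the sign element reports the sign of the weighted sum.
val, sign = logsumexp(2.0, b=-1.0, return_sign=True)

# ND input with an explicit axis: an array of np.float64 per the ND overloads.
row_totals = logsumexp(x, axis=1)

# softmax / log_softmax: array in -> array out (scalar in -> scalar out).
probs = softmax(x, axis=-1)
log_probs = log_softmax(x, axis=-1)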