special: improve logsumexp, softmax and log_softmax
jorenham committed Dec 31, 2024 · 1 parent 7c51738 · commit a9a79f4
Showing 1 changed file with 67 additions and 19 deletions.
scipy-stubs/special/_logsumexp.pyi
@@ -1,50 +1,98 @@
-from typing import Literal, overload
+from typing import overload
 
 import numpy as np
 import optype.numpy as onp
+from scipy._typing import AnyShape, Falsy, Truthy
 
 __all__ = ["log_softmax", "logsumexp", "softmax"]
 
-# TODO: Support `return_sign=True`
 @overload
 def logsumexp(
     a: onp.ToFloat,
-    axis: int | tuple[int, ...] | None = None,
+    axis: AnyShape | None = None,
     b: onp.ToFloat | None = None,
     keepdims: bool = False,
-    return_sign: Literal[False, 0] = False,
+    return_sign: Falsy = False,
 ) -> np.float64: ...
 @overload
 def logsumexp(
     a: onp.ToComplex,
-    axis: int | tuple[int, ...] | None = None,
+    axis: AnyShape | None = None,
     b: onp.ToFloat | None = None,
     keepdims: bool = False,
-    return_sign: Literal[False, 0] = False,
+    return_sign: Falsy = False,
 ) -> np.float64 | np.complex128: ...
 @overload
 def logsumexp(
     a: onp.ToFloatND,
-    axis: int | tuple[int, ...],
+    axis: AnyShape,
     b: onp.ToFloat | onp.ToFloatND | None = None,
     keepdims: bool = False,
-    return_sign: Literal[False, 0] = False,
+    return_sign: Falsy = False,
 ) -> np.float64 | onp.ArrayND[np.float64]: ...
 @overload
 def logsumexp(
     a: onp.ToComplexND,
-    axis: int | tuple[int, ...],
+    axis: AnyShape,
     b: onp.ToFloat | onp.ToFloatND | None = None,
     keepdims: bool = False,
-    return_sign: Literal[False, 0] = False,
+    return_sign: Falsy = False,
 ) -> np.float64 | np.complex128 | onp.ArrayND[np.float64 | np.complex128]: ...
+@overload
+def logsumexp(
+    a: onp.ToFloat,
+    axis: AnyShape | None = None,
+    b: onp.ToFloat | None = None,
+    keepdims: bool = False,
+    *,
+    return_sign: Truthy,
+) -> tuple[np.float64, bool | np.bool_]: ...
+@overload
+def logsumexp(
+    a: onp.ToComplex,
+    axis: AnyShape | None = None,
+    b: onp.ToFloat | None = None,
+    keepdims: bool = False,
+    *,
+    return_sign: Truthy,
+) -> tuple[np.float64 | np.complex128, bool | np.bool_]: ...
+@overload
+def logsumexp(
+    a: onp.ToFloatND,
+    axis: AnyShape,
+    b: onp.ToFloat | onp.ToFloatND | None = None,
+    keepdims: bool = False,
+    *,
+    return_sign: Truthy,
+) -> tuple[np.float64, bool | np.bool_] | tuple[onp.ArrayND[np.float64], onp.ArrayND[np.bool_]]: ...
+@overload
+def logsumexp(
+    a: onp.ToComplexND,
+    axis: AnyShape,
+    b: onp.ToFloat | onp.ToFloatND | None = None,
+    keepdims: bool = False,
+    *,
+    return_sign: Truthy,
+) -> (
+    tuple[np.float64 | np.complex128, bool | np.bool_] | tuple[onp.ArrayND[np.float64 | np.complex128], onp.ArrayND[np.bool_]]
+): ...
 
-# TODO: Overload real/complex and scalar/array
-def softmax(
-    x: onp.ToComplex | onp.ToComplexND,
-    axis: int | tuple[int, ...] | None = None,
-) -> np.float64 | np.complex128 | onp.ArrayND[np.float64 | np.complex128]: ...
-def log_softmax(
-    x: onp.ToComplex | onp.ToComplexND,
-    axis: int | tuple[int, ...] | None = None,
-) -> np.float64 | np.complex128 | onp.ArrayND[np.float64 | np.complex128]: ...
+#
+@overload
+def softmax(x: onp.ToFloat, axis: AnyShape | None = None) -> np.float64: ...
+@overload
+def softmax(x: onp.ToFloatND, axis: AnyShape | None = None) -> onp.ArrayND[np.float64]: ...
+@overload
+def softmax(x: onp.ToComplex, axis: AnyShape | None = None) -> np.float64 | np.complex128: ...
+@overload
+def softmax(x: onp.ToComplexND, axis: AnyShape | None = None) -> onp.ArrayND[np.float64 | np.complex128]: ...
+
+#
+@overload
+def log_softmax(x: onp.ToFloat, axis: AnyShape | None = None) -> np.float64: ...
+@overload
+def log_softmax(x: onp.ToFloatND, axis: AnyShape | None = None) -> onp.ArrayND[np.float64]: ...
+@overload
+def log_softmax(x: onp.ToComplex, axis: AnyShape | None = None) -> np.float64 | np.complex128: ...
+@overload
+def log_softmax(x: onp.ToComplexND, axis: AnyShape | None = None) -> onp.ArrayND[np.float64 | np.complex128]: ...
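
As a minimal sketch (not part of the commit) of what the new `Falsy`/`Truthy` overloads mean for a type checker such as mypy or pyright, assuming scipy-stubs is installed:

    import numpy as np

    from scipy.special import logsumexp

    x = np.array([1.0, -2.0, 3.0])

    # return_sign is falsy by default, so the `Falsy` overloads match and the
    # result is typed as a plain value (np.float64 | onp.ArrayND[np.float64]).
    total = logsumexp(x, axis=0)

    # return_sign=True is keyword-only in the new overloads and selects the
    # `Truthy` branch, so the result is typed as a (value, sign) pair rather
    # than a bare np.float64, and unpacking it now type-checks.
    value, sign = logsumexp(x, axis=0, b=np.array([1.0, -1.0, 1.0]), return_sign=True)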

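In the same hedged spirit, a sketch of the effect of the replacement softmax and log_softmax overloads, which split the former catch-all signature into real/complex and scalar/array variants:

    import numpy as np

    from scipy.special import log_softmax, softmax

    # A real array input now maps to a precise return type,
    # onp.ArrayND[np.float64], instead of the old four-way union.
    p = softmax(np.array([1.0, 2.0, 3.0]))
    lp = log_softmax(np.array([1.0, 2.0, 3.0]))

    # Complex input is still accepted per the stubs, but it lives in its own
    # overload, so real-valued callers no longer see np.complex128 in their
    # inferred types.
    z = softmax(np.array([1 + 1j, 2 - 1j]))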