Skip to content

Commit

Permalink
👽️ sparse: accept axes: tuple[int[:], int[:]] in `coo_array.tenso…
Browse files Browse the repository at this point in the history
…rdot` (#357)
  • Loading branch information
jorenham authored Dec 19, 2024
2 parents d26fec6 + 8276649 commit 695c003
Showing 1 changed file with 11 additions and 9 deletions.
20 changes: 11 additions & 9 deletions scipy-stubs/sparse/_coo.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ _SupComplex: TypeAlias = np.complex128 | np.clongdouble
_SupFloat: TypeAlias = np.float64 | np.longdouble | _SupComplex
_SupInt: TypeAlias = np.int_ | np.int64 | np.uint32 | np.uintp | np.uint | np.uint64 | _SupFloat

_Axes: TypeAlias = int | tuple[Sequence[int], Sequence[int]]

_SupComplexT = TypeVar("_SupComplexT", bound=_SupComplex)
_SupFloatT = TypeVar("_SupFloatT", bound=_SupFloat)
_SupIntT = TypeVar("_SupIntT", bound=_SupInt)
Expand Down Expand Up @@ -307,23 +309,23 @@ class _coo_base(_data_matrix[_SCT, _ShapeT_co], _minmax_mixin[_SCT, _ShapeT_co],
# would result in more overloads than that mypy has bugs (i.e. >1_200).
# NOTE: due to a bug in `axes`, only `int` can be used at the moment (passing a 2-tuple or 2-list raises `TypeError`)
@overload
def tensordot(self, /, other: _spbase[_SCT0], axes: int = 2) -> _SCT | _SCT0 | coo_array[_SCT | _SCT0]: ...
def tensordot(self, /, other: _spbase[_SCT0], axes: _Axes = 2) -> _SCT | _SCT0 | coo_array[_SCT | _SCT0]: ...
@overload
def tensordot(self, /, other: _ToDense[_SCT0], axes: int = 2) -> _ScalarOrDense[_SCT | _SCT0]: ...
def tensordot(self, /, other: _ToDense[_SCT0], axes: _Axes = 2) -> _ScalarOrDense[_SCT | _SCT0]: ...
@overload
def tensordot(self, /, other: onp.SequenceND[bool], axes: int = 2) -> _ScalarOrDense[_SCT]: ...
def tensordot(self, /, other: onp.SequenceND[bool], axes: _Axes = 2) -> _ScalarOrDense[_SCT]: ...
@overload
def tensordot(self: _spbase[_SubInt], /, other: _JustND[int], axes: int = 2) -> _ScalarOrDense[np.int_]: ...
def tensordot(self: _spbase[_SubInt], /, other: _JustND[int], axes: _Axes = 2) -> _ScalarOrDense[np.int_]: ...
@overload
def tensordot(self: _spbase[_SubFloat], /, other: _JustND[float], axes: int = 2) -> _ScalarOrDense[np.float64]: ...
def tensordot(self: _spbase[_SubFloat], /, other: _JustND[float], axes: _Axes = 2) -> _ScalarOrDense[np.float64]: ...
@overload
def tensordot(self: _spbase[_SubComplex], /, other: _JustND[complex], axes: int = 2) -> _ScalarOrDense[np.complex128]: ...
def tensordot(self: _spbase[_SubComplex], /, other: _JustND[complex], axes: _Axes = 2) -> _ScalarOrDense[np.complex128]: ...
@overload
def tensordot(self: _spbase[_SupComplexT], /, other: _JustND[complex], axes: int = 2) -> _ScalarOrDense[_SupComplexT]: ...
def tensordot(self: _spbase[_SupComplexT], /, other: _JustND[complex], axes: _Axes = 2) -> _ScalarOrDense[_SupComplexT]: ...
@overload
def tensordot(self: _spbase[_SupFloatT], /, other: _JustND[float], axes: int = 2) -> _ScalarOrDense[_SupFloatT]: ...
def tensordot(self: _spbase[_SupFloatT], /, other: _JustND[float], axes: _Axes = 2) -> _ScalarOrDense[_SupFloatT]: ...
@overload
def tensordot(self: _spbase[_SupIntT], /, other: _JustND[int], axes: int = 2) -> _ScalarOrDense[_SupIntT]: ...
def tensordot(self: _spbase[_SupIntT], /, other: _JustND[int], axes: _Axes = 2) -> _ScalarOrDense[_SupIntT]: ...

class coo_array(_coo_base[_SCT, _ShapeT_co], sparray, Generic[_SCT, _ShapeT_co]): ...

Expand Down

0 comments on commit 695c003

Please sign in to comment.