From 8276649a5b5c2ea1ea317d2850b4fd375839ba02 Mon Sep 17 00:00:00 2001 From: jorenham Date: Thu, 19 Dec 2024 04:19:47 +0100 Subject: [PATCH] =?UTF-8?q?=F0=9F=91=BD=EF=B8=8F=20`sparse`:=20accept=20`a?= =?UTF-8?q?xes:=20tuple[int[:],=20int[:]]`=20in=20`coo=5Farray.tensordot`?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- scipy-stubs/sparse/_coo.pyi | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/scipy-stubs/sparse/_coo.pyi b/scipy-stubs/sparse/_coo.pyi index 41dc8879..d23852e5 100644 --- a/scipy-stubs/sparse/_coo.pyi +++ b/scipy-stubs/sparse/_coo.pyi @@ -48,6 +48,8 @@ _SupComplex: TypeAlias = np.complex128 | np.clongdouble _SupFloat: TypeAlias = np.float64 | np.longdouble | _SupComplex _SupInt: TypeAlias = np.int_ | np.int64 | np.uint32 | np.uintp | np.uint | np.uint64 | _SupFloat +_Axes: TypeAlias = int | tuple[Sequence[int], Sequence[int]] + _SupComplexT = TypeVar("_SupComplexT", bound=_SupComplex) _SupFloatT = TypeVar("_SupFloatT", bound=_SupFloat) _SupIntT = TypeVar("_SupIntT", bound=_SupInt) @@ -307,23 +309,23 @@ class _coo_base(_data_matrix[_SCT, _ShapeT_co], _minmax_mixin[_SCT, _ShapeT_co], # would result in more overloads than that mypy has bugs (i.e. >1_200). # NOTE: due to a bug in `axes`, only `int` can be used at the moment (passing a 2-tuple or 2-list raises `TypeError`) @overload - def tensordot(self, /, other: _spbase[_SCT0], axes: int = 2) -> _SCT | _SCT0 | coo_array[_SCT | _SCT0]: ... + def tensordot(self, /, other: _spbase[_SCT0], axes: _Axes = 2) -> _SCT | _SCT0 | coo_array[_SCT | _SCT0]: ... @overload - def tensordot(self, /, other: _ToDense[_SCT0], axes: int = 2) -> _ScalarOrDense[_SCT | _SCT0]: ... + def tensordot(self, /, other: _ToDense[_SCT0], axes: _Axes = 2) -> _ScalarOrDense[_SCT | _SCT0]: ... @overload - def tensordot(self, /, other: onp.SequenceND[bool], axes: int = 2) -> _ScalarOrDense[_SCT]: ... + def tensordot(self, /, other: onp.SequenceND[bool], axes: _Axes = 2) -> _ScalarOrDense[_SCT]: ... @overload - def tensordot(self: _spbase[_SubInt], /, other: _JustND[int], axes: int = 2) -> _ScalarOrDense[np.int_]: ... + def tensordot(self: _spbase[_SubInt], /, other: _JustND[int], axes: _Axes = 2) -> _ScalarOrDense[np.int_]: ... @overload - def tensordot(self: _spbase[_SubFloat], /, other: _JustND[float], axes: int = 2) -> _ScalarOrDense[np.float64]: ... + def tensordot(self: _spbase[_SubFloat], /, other: _JustND[float], axes: _Axes = 2) -> _ScalarOrDense[np.float64]: ... @overload - def tensordot(self: _spbase[_SubComplex], /, other: _JustND[complex], axes: int = 2) -> _ScalarOrDense[np.complex128]: ... + def tensordot(self: _spbase[_SubComplex], /, other: _JustND[complex], axes: _Axes = 2) -> _ScalarOrDense[np.complex128]: ... @overload - def tensordot(self: _spbase[_SupComplexT], /, other: _JustND[complex], axes: int = 2) -> _ScalarOrDense[_SupComplexT]: ... + def tensordot(self: _spbase[_SupComplexT], /, other: _JustND[complex], axes: _Axes = 2) -> _ScalarOrDense[_SupComplexT]: ... @overload - def tensordot(self: _spbase[_SupFloatT], /, other: _JustND[float], axes: int = 2) -> _ScalarOrDense[_SupFloatT]: ... + def tensordot(self: _spbase[_SupFloatT], /, other: _JustND[float], axes: _Axes = 2) -> _ScalarOrDense[_SupFloatT]: ... @overload - def tensordot(self: _spbase[_SupIntT], /, other: _JustND[int], axes: int = 2) -> _ScalarOrDense[_SupIntT]: ... + def tensordot(self: _spbase[_SupIntT], /, other: _JustND[int], axes: _Axes = 2) -> _ScalarOrDense[_SupIntT]: ... class coo_array(_coo_base[_SCT, _ShapeT_co], sparray, Generic[_SCT, _ShapeT_co]): ...