Skip to content

Commit

Permalink
spatial: generic [c]KDTree and other improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
jorenham committed Jan 9, 2025
1 parent 99abc24 commit abe6970
Show file tree
Hide file tree
Showing 2 changed files with 230 additions and 87 deletions.
245 changes: 194 additions & 51 deletions scipy-stubs/spatial/_ckdtree.pyi
Original file line number Diff line number Diff line change
@@ -1,76 +1,119 @@
from typing import Literal as L, TypeAlias, overload, type_check_only
from typing import Generic, Literal as L, Protocol, TypeAlias, overload, type_check_only
from typing_extensions import TypeVar, override

import numpy as np
import optype as op
import optype.numpy as onp
from scipy._typing import Falsy, Truthy
from scipy.sparse import coo_matrix, dok_matrix

__all__ = ["cKDTree"]

_Weights: TypeAlias = onp.ToFloatND | tuple[onp.ToFloatND, onp.ToFloatND]
_Indices: TypeAlias = onp.Array1D[np.intp]
_Float1D: TypeAlias = onp.Array1D[np.float64]
_Float2D: TypeAlias = onp.Array2D[np.float64]

_NodeT_co = TypeVar("_NodeT_co", bound=_KDTreeNode | None, default=_KDTreeNode | None, covariant=True)
_BoxSizeT_co = TypeVar("_BoxSizeT_co", bound=_Float2D | None, default=_Float2D | None, covariant=True)
_BoxSizeDataT_co = TypeVar("_BoxSizeDataT_co", bound=_Float1D | None, default=_Float1D | None, covariant=True)

@type_check_only
class _CythonMixin:
def __setstate_cython__(self, pyx_state: object, /) -> None: ...
def __reduce_cython__(self, /) -> None: ...

class cKDTreeNode(_CythonMixin):
@property
def data_points(self, /) -> onp.ArrayND[np.float64]: ...
@property
def indices(self, /) -> onp.ArrayND[np.intp]: ...

# These are read-only attributes in cython, which behave like properties
# workaround for mypy's lack of cyclical TypeVar support
@type_check_only
class _KDTreeNode(Protocol):
@property
def level(self, /) -> int: ...
@property
def split_dim(self, /) -> int: ...
@property
def split(self, /) -> float: ...
@property
def children(self, /) -> int: ...
@property
def data_points(self, /) -> _Float2D: ...
@property
def indices(self, /) -> _Indices: ...
@property
def start_idx(self, /) -> int: ...
@property
def end_idx(self, /) -> int: ...
@property
def split(self, /) -> float: ...
@property
def lesser(self, /) -> cKDTreeNode | None: ...
def lesser(self, /) -> _KDTreeNode | None: ...
@property
def greater(self, /) -> cKDTreeNode | None: ...
def greater(self, /) -> _KDTreeNode | None: ...

###

class cKDTree(_CythonMixin):
class cKDTreeNode(_CythonMixin, _KDTreeNode, Generic[_NodeT_co]):
@property
def n(self, /) -> int: ...
@override
def lesser(self, /) -> _NodeT_co: ...
@property
def m(self, /) -> int: ...
@override
def greater(self, /) -> _NodeT_co: ...

class cKDTree(_CythonMixin, Generic[_BoxSizeT_co, _BoxSizeDataT_co]):
@property
def data(self, /) -> _Float2D: ...
@property
def leafsize(self, /) -> int: ...
@property
def size(self, /) -> int: ...
def m(self, /) -> int: ...
@property
def tree(self, /) -> cKDTreeNode: ...

# These are read-only attributes in cython, which behave like properties
def n(self, /) -> int: ...
@property
def maxes(self, /) -> _Float1D: ...
@property
def data(self, /) -> onp.ArrayND[np.float64]: ...
def mins(self, /) -> _Float1D: ...
@property
def maxes(self, /) -> onp.ArrayND[np.float64]: ...
def tree(self, /) -> cKDTreeNode: ...
@property
def mins(self, /) -> onp.ArrayND[np.float64]: ...
def size(self, /) -> int: ...
@property
def indices(self, /) -> onp.ArrayND[np.float64]: ...
def indices(self, /) -> _Indices: ...
@property
def boxsize(self, /) -> onp.ArrayND[np.float64] | None: ...
def boxsize(self, /) -> _BoxSizeT_co: ...
boxsize_data: _BoxSizeDataT_co

#
@overload
def __init__(
self,
self: cKDTree[None, None],
/,
data: onp.ToComplexND,
leafsize: int = ...,
compact_nodes: bool = ...,
copy_data: bool = ...,
balanced_tree: bool = ...,
boxsize: onp.ToFloat2D | None = ...,
data: onp.ToFloat2D,
leafsize: int = 16,
compact_nodes: bool = True,
copy_data: bool = False,
balanced_tree: bool = True,
boxsize: None = None,
) -> None: ...
@overload
def __init__(
self: cKDTree[_Float2D, _Float1D],
/,
data: onp.ToFloat2D,
leafsize: int,
compact_nodes: bool,
copy_data: bool,
balanced_tree: bool,
boxsize: onp.ToFloat2D,
) -> None: ...
@overload
def __init__(
self: cKDTree[_Float2D, _Float1D],
/,
data: onp.ToFloat2D,
leafsize: int = 16,
compact_nodes: bool = True,
copy_data: bool = False,
balanced_tree: bool = True,
*,
boxsize: onp.ToFloat2D,
) -> None: ...

#
Expand All @@ -79,24 +122,124 @@ class cKDTree(_CythonMixin):
/,
x: onp.ToFloat1D,
k: onp.ToInt | onp.ToInt1D = 1,
eps: onp.ToFloat = 0.0,
p: onp.ToFloat = 2.0,
distance_upper_bound: float = ..., # inf
workers: int | None = None,
eps: onp.ToFloat = ...,
p: onp.ToFloat = ...,
distance_upper_bound: float = float("inf"), # noqa: PYI011
workers: int | None = ...,
) -> tuple[float, np.intp] | tuple[onp.ArrayND[np.float64], onp.ArrayND[np.intp]]: ...

#
@overload
def query_ball_point(
self,
/,
x: onp.ToFloatStrict1D,
r: onp.ToFloat,
p: onp.ToFloat = 2.0,
eps: onp.ToFloat = ...,
workers: op.CanIndex | None = None,
return_sorted: onp.ToBool | None = None,
return_length: Falsy = False,
) -> list[int]: ...
@overload
def query_ball_point(
self,
/,
x: onp.ToFloatStrict1D,
r: onp.ToFloat,
p: onp.ToFloat,
eps: onp.ToFloat,
workers: op.CanIndex | None,
return_sorted: onp.ToBool | None,
return_length: Truthy,
) -> np.intp: ...
@overload
def query_ball_point(
self,
/,
x: onp.ToFloatStrict1D,
r: onp.ToFloat,
p: onp.ToFloat = 2.0,
eps: onp.ToFloat = ...,
workers: op.CanIndex | None = None,
return_sorted: onp.ToBool | None = None,
*,
return_length: Truthy,
) -> np.intp: ...
@overload
def query_ball_point(
self,
/,
x: onp.ToFloatND,
r: onp.ToFloatND,
p: onp.ToFloat = 2.0,
eps: onp.ToFloat = ...,
workers: op.CanIndex | None = None,
return_sorted: onp.ToBool | None = None,
return_length: Falsy = False,
) -> onp.ArrayND[np.object_]: ...
@overload
def query_ball_point(
self,
/,
x: onp.ToFloatND,
r: onp.ToFloatND,
p: onp.ToFloat,
eps: onp.ToFloat,
workers: op.CanIndex | None,
return_sorted: onp.ToBool | None,
return_length: Truthy,
) -> onp.ArrayND[np.intp]: ...
@overload
def query_ball_point(
self,
/,
x: onp.ToFloatND,
r: onp.ToFloatND,
p: onp.ToFloat = 2.0,
eps: onp.ToFloat = ...,
workers: op.CanIndex | None = None,
return_sorted: onp.ToBool | None = None,
*,
return_length: Truthy,
) -> onp.ArrayND[np.intp]: ...
@overload
def query_ball_point(
self,
/,
x: onp.ToFloatND,
r: onp.ToFloat | onp.ToFloatND,
p: onp.ToFloat = 2.0,
eps: onp.ToFloat = 0.0,
workers: int | None = None,
return_sorted: bool | None = None,
return_length: bool = False,
eps: onp.ToFloat = ...,
workers: op.CanIndex | None = None,
return_sorted: onp.ToBool | None = None,
return_length: Falsy = False,
) -> list[int] | onp.ArrayND[np.object_]: ...
@overload
def query_ball_point(
self,
/,
x: onp.ToFloatND,
r: onp.ToFloat | onp.ToFloatND,
p: onp.ToFloat,
eps: onp.ToFloat,
workers: op.CanIndex | None,
return_sorted: onp.ToBool | None,
return_length: Truthy,
) -> np.intp | onp.ArrayND[np.intp]: ...
@overload
def query_ball_point(
self,
/,
x: onp.ToFloatND,
r: onp.ToFloat | onp.ToFloatND,
p: onp.ToFloat = 2.0,
eps: onp.ToFloat = ...,
workers: op.CanIndex | None = None,
return_sorted: onp.ToBool | None = None,
*,
return_length: Truthy,
) -> np.intp | onp.ArrayND[np.intp]: ...

#
def query_ball_tree(
Expand All @@ -105,7 +248,7 @@ class cKDTree(_CythonMixin):
other: cKDTree,
r: onp.ToFloat,
p: onp.ToFloat = 2.0,
eps: onp.ToFloat = 0.0,
eps: onp.ToFloat = ..., # defaults to `0.0`, but is overridden in `KDTree` with `0` as default
) -> list[list[int]]: ...

#
Expand Down Expand Up @@ -144,7 +287,7 @@ class cKDTree(_CythonMixin):
self,
/,
other: cKDTree,
r: onp.ToScalar,
r: onp.ToFloat,
p: onp.ToFloat = 2.0,
weights: tuple[None, None] | None = None,
cumulative: bool = True,
Expand All @@ -154,7 +297,7 @@ class cKDTree(_CythonMixin):
self,
/,
other: cKDTree,
r: onp.ToScalar,
r: onp.ToFloat,
p: onp.ToFloat,
weights: _Weights,
cumulative: bool = True,
Expand All @@ -164,7 +307,7 @@ class cKDTree(_CythonMixin):
self,
/,
other: cKDTree,
r: onp.ToScalar,
r: onp.ToFloat,
p: onp.ToFloat = 2.0,
*,
weights: _Weights,
Expand All @@ -175,32 +318,32 @@ class cKDTree(_CythonMixin):
self,
/,
other: cKDTree,
r: onp.ToFloat | onp.ToFloatND,
r: onp.ToFloat | onp.ToFloat1D,
p: onp.ToFloat = 2.0,
weights: tuple[None, None] | None = ...,
cumulative: bool = True,
) -> np.float64 | np.intp | onp.ArrayND[np.intp]: ...
) -> np.intp | onp.Array1D[np.intp]: ...
@overload
def count_neighbors(
self,
/,
other: cKDTree,
r: onp.ToFloat | onp.ToFloatND,
r: onp.ToFloat | onp.ToFloat1D,
p: onp.ToFloat,
weights: _Weights,
cumulative: bool = True,
) -> np.float64 | np.intp | onp.ArrayND[np.float64]: ...
) -> np.float64 | onp.Array1D[np.float64]: ...
@overload
def count_neighbors(
self,
/,
other: cKDTree,
r: onp.ToFloat | onp.ToFloatND,
r: onp.ToFloat | onp.ToFloat1D,
p: onp.ToFloat = 2.0,
*,
weights: _Weights,
cumulative: bool = True,
) -> np.float64 | np.intp | onp.ArrayND[np.float64]: ...
) -> np.float64 | onp.Array1D[np.float64]: ...

#
@overload
Expand All @@ -211,7 +354,7 @@ class cKDTree(_CythonMixin):
max_distance: onp.ToFloat,
p: onp.ToFloat = 2.0,
output_type: L["dok_matrix"] = ...,
) -> dok_matrix: ...
) -> dok_matrix[np.float64]: ...
@overload
def sparse_distance_matrix(
self,
Expand All @@ -221,7 +364,7 @@ class cKDTree(_CythonMixin):
p: onp.ToFloat = 2.0,
*,
output_type: L["coo_matrix"],
) -> coo_matrix: ...
) -> coo_matrix[np.float64]: ...
@overload
def sparse_distance_matrix(
self,
Expand Down
Loading

0 comments on commit abe6970

Please sign in to comment.