Skip to content

Commit

Permalink
ENH: Dask: sort and argsort
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky committed Jan 23, 2025
1 parent 8a79994 commit 1a7316f
Show file tree
Hide file tree
Showing 3 changed files with 158 additions and 23 deletions.
86 changes: 81 additions & 5 deletions array_api_compat/dask/array/_aliases.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import annotations

from ...common import _aliases
from typing import Literal

from ...common import _aliases, array_namespace

from ..._internal import get_xp

Expand Down Expand Up @@ -228,10 +230,84 @@ def _isscalar(a):

return astype(da.minimum(da.maximum(x, min), max), x.dtype)

# exclude these from all since dask.array has no sorting functions
_da_unsupported = ['sort', 'argsort']

_common_aliases = [alias for alias in _aliases.__all__ if alias not in _da_unsupported]
def sort(
x: Array, /, *, axis: int = -1, descending: bool = False, stable: bool = True
) -> Array:
"""
Array API compatibility layer around the lack of sort() in Dask.
Warnings
--------
This function temporarily rechunks the array along `axis` to a single chunk.
This can be extremely inefficient and can lead to out-of-memory errors.
See the corresponding documentation in the array library and/or the array API
specification for more details.
"""
return _sort_argsort("sort", x, axis=axis, descending=descending, stable=stable)


def argsort(
x: Array, /, *, axis: int = -1, descending: bool = False, stable: bool = True
) -> Array:
"""
Array API compatibility layer around the lack of argsort() in Dask.
See the corresponding documentation in the array library and/or the array API
specification for more details.
Warnings
--------
This function temporarily rechunks the array along `axis` into a single chunk.
This can be extremely inefficient and can lead to out-of-memory errors.
"""
return _sort_argsort("argsort", x, axis=axis, descending=descending, stable=stable)


def _sort_argsort(
func: Literal["sort", "argsort"],
x: Array,
/,
*,
axis: int,
descending: bool,
stable: bool,
) -> Array:
"""
Implementation of sort() and argsort()
TODO Implement sort and argsort properly in Dask on top of the shuffle subsystem.
"""
if axis < 0:
axis += x.ndim
rechunk = False
if x.numblocks[axis] > 1:
rechunk = True
# Break chunks on other axes in an attempt to keep chunk size low
x = x.rechunk({i: -1 if i == axis else "auto" for i in range(x.ndim)})
meta_xp = array_namespace(x._meta)
x = da.map_blocks(
getattr(meta_xp, func),
x,
axis=axis,
descending=descending,
stable=stable,
dtype=x.dtype,
meta=x._meta,
)
if rechunk:
# rather than reconstructing the original chunks, which can be a
# very expensive affair, just break down oversized chunks without
# incurring in any transfers over the network.
# This has the downside of a risk of overchunking if the array is
# then used in operations against other arrays that match the
# original chunking pattern.
x = x.rechunk()
return x


_common_aliases = _aliases.__all__

__all__ = _common_aliases + ['__array_namespace_info__', 'asarray', 'acos',
'acosh', 'asin', 'asinh', 'atan', 'atan2',
Expand All @@ -242,4 +318,4 @@ def _isscalar(a):
'complex64', 'complex128', 'iinfo', 'finfo',
'can_cast', 'result_type']

_all_ignore = ["get_xp", "da", "np"]
_all_ignore = ["Literal", "array_namespace", "get_xp", "da", "np"]
22 changes: 5 additions & 17 deletions dask-xfails.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,13 @@ array_api_tests/test_array_object.py::test_setitem_masking
# Various indexing errors
array_api_tests/test_array_object.py::test_getitem_masking

# asarray(copy=False) is not yet implemented
# copied from numpy xfails, TODO: should this pass with dask?
array_api_tests/test_creation_functions.py::test_asarray_arrays

# zero division error, and typeerror: tuple indices must be integers or slices not tuple
array_api_tests/test_creation_functions.py::test_eye

# finfo(float32).eps returns float32 but should return float
array_api_tests/test_data_type_functions.py::test_finfo[float32]

# out[-1]=dask.aray<getitem ...> but should be some floating number
# out[-1]=dask.array<getitem ...> but should be some floating number
# (I think the test is not forcing the op to be computed?)
array_api_tests/test_creation_functions.py::test_linspace

Expand All @@ -48,15 +44,7 @@ array_api_tests/test_special_cases.py::test_iop[__ipow__(x1_i is -0 and x2_i > 0
array_api_tests/test_special_cases.py::test_iop[__ipow__(x1_i is -infinity and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +infinity]
array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -infinity and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +infinity]

# No sorting in dask
array_api_tests/test_has_names.py::test_has_names[sorting-argsort]
array_api_tests/test_has_names.py::test_has_names[sorting-sort]
array_api_tests/test_sorting_functions.py::test_argsort
array_api_tests/test_sorting_functions.py::test_sort
array_api_tests/test_signatures.py::test_func_signature[argsort]
array_api_tests/test_signatures.py::test_func_signature[sort]

# Array methods and attributes not already on np.ndarray cannot be wrapped
# Array methods and attributes not already on da.Array cannot be wrapped
array_api_tests/test_has_names.py::test_has_names[array_method-__array_namespace__]
array_api_tests/test_has_names.py::test_has_names[array_method-to_device]
array_api_tests/test_has_names.py::test_has_names[array_attribute-device]
Expand All @@ -76,6 +64,7 @@ array_api_tests/test_set_functions.py::test_unique_values
# fails for ndim > 2
array_api_tests/test_linalg.py::test_svdvals
array_api_tests/test_linalg.py::test_cholesky

# dtype mismatch got uint64, but should be uint8, NPY_PROMOTION_STATE=weak doesn't help :(
array_api_tests/test_linalg.py::test_tensordot

Expand Down Expand Up @@ -105,6 +94,8 @@ array_api_tests/test_linalg.py::test_cross
array_api_tests/test_linalg.py::test_det
array_api_tests/test_linalg.py::test_eigh
array_api_tests/test_linalg.py::test_eigvalsh
array_api_tests/test_linalg.py::test_matrix_norm
array_api_tests/test_linalg.py::test_matrix_rank
array_api_tests/test_linalg.py::test_pinv
array_api_tests/test_linalg.py::test_slogdet
array_api_tests/test_has_names.py::test_has_names[linalg-cross]
Expand All @@ -115,9 +106,6 @@ array_api_tests/test_has_names.py::test_has_names[linalg-matrix_power]
array_api_tests/test_has_names.py::test_has_names[linalg-pinv]
array_api_tests/test_has_names.py::test_has_names[linalg-slogdet]

array_api_tests/test_linalg.py::test_matrix_norm
array_api_tests/test_linalg.py::test_matrix_rank

# missing mode kw
# https://github.com/dask/dask/issues/10388
array_api_tests/test_linalg.py::test_qr
Expand Down
73 changes: 72 additions & 1 deletion tests/test_dask.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from contextlib import contextmanager

import array_api_strict
import dask
import numpy as np
import pytest
Expand All @@ -20,9 +21,10 @@ def assert_no_compute():
Context manager that raises if at any point inside it anything calls compute()
or persist(), e.g. as it can be triggered implicitly by __bool__, __array__, etc.
"""

def get(dsk, *args, **kwargs):
raise AssertionError("Called compute() or persist()")

with dask.config.set(scheduler=get):
yield

Expand All @@ -40,6 +42,7 @@ def test_assert_no_compute():

# Test no_compute for functions that use generic _aliases with xp=np


def test_unary_ops_no_compute(xp):
with assert_no_compute():
a = xp.asarray([1.5, -1.5])
Expand All @@ -59,6 +62,7 @@ def test_matmul_tensordot_no_compute(xp):

# Test no_compute for functions that are fully bespoke for dask


def test_asarray_no_compute(xp):
with assert_no_compute():
a = xp.arange(10)
Expand Down Expand Up @@ -88,6 +92,14 @@ def test_clip_no_compute(xp):
xp.clip(a, 1, 8)


@pytest.mark.parametrize("chunks", (5, 10))
def test_sort_argsort_nocompute(xp, chunks):
with assert_no_compute():
a = xp.arange(10, chunks=chunks)
xp.sort(a)
xp.argsort(a)


def test_generators_are_lazy(xp):
"""
Test that generator functions are fully lazy, e.g. that
Expand All @@ -106,3 +118,62 @@ def test_generators_are_lazy(xp):
xp.ones_like(a)
xp.empty_like(a)
xp.full_like(a, fill_value=123)


@pytest.mark.parametrize("axis", [0, 1])
@pytest.mark.parametrize("func", ["sort", "argsort"])
def test_sort_argsort_chunks(xp, func, axis):
"""Test that sort and argsort are functionally correct when
the array is chunked along the sort axis, e.g. the sort is
not just local to each chunk.
"""
a = da.random.random((10, 10), chunks=(5, 5))
actual = getattr(xp, func)(a, axis=axis)
expect = getattr(np, func)(a.compute(), axis=axis)
np.testing.assert_array_equal(actual, expect)


@pytest.mark.parametrize(
"shape,chunks",
[
# 3 GiB; 128 MiB per chunk; must rechunk before sorting.
# Sort chunks can be 128 MiB each; no need for final rechunk.
((20_000, 20_000), "auto"),
# 3 GiB; 128 MiB per chunk; must rechunk before sorting.
# Must sort on two 1.5 GiB chunks; benefits from final rechunk.
((2, 2**30 * 3 // 16), "auto"),
# 3 GiB; 1.5 GiB per chunk; no need to rechunk before sorting.
# Surely the user must know what they're doing, so don't
# perform the final rechunk.
((2, 2**30 * 3 // 16), (1, -1)),
],
)
@pytest.mark.parametrize("func", ["sort", "argsort"])
def test_sort_argsort_chunk_size(xp, func, shape, chunks):
"""
Test that sort and argsort produce reasonably-sized chunks
in the output array, even if they had to go through a singular
huge one to perform the operation.
"""
a = da.random.random(shape, chunks=chunks)
b = getattr(xp, func)(a)
max_chunk_size = max(b.chunks[0]) * max(b.chunks[1]) * b.dtype.itemsize
assert (
max_chunk_size <= 128 * 1024 * 1024 # 128 MiB
or b.chunks == a.chunks
)


@pytest.mark.parametrize("func", ["sort", "argsort"])
def test_sort_argsort_meta(xp, func):
"""Test meta-namespace other than numpy"""
typ = type(array_api_strict.asarray(0))
a = da.random.random(10)
b = a.map_blocks(array_api_strict.asarray)
assert isinstance(b._meta, typ)
c = getattr(xp, func)(b)
assert isinstance(c._meta, typ)
d = c.compute()
# Note: np.sort(array_api_strict.asarray(0)) would return a numpy array
assert isinstance(d, typ)
np.testing.assert_array_equal(d, getattr(np, func)(a.compute()))

0 comments on commit 1a7316f

Please sign in to comment.