From 0be10746f9bf9ab28926dbbb68681f2d4eb09c1c Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Wed, 27 Mar 2024 13:30:47 +0000 Subject: [PATCH] Rudimentary tests for `SupportsIndex` in indexing methods --- array_api_tests/test_array_object.py | 35 ++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/array_api_tests/test_array_object.py b/array_api_tests/test_array_object.py index bc3e7276..a5541943 100644 --- a/array_api_tests/test_array_object.py +++ b/array_api_tests/test_array_object.py @@ -153,6 +153,41 @@ def test_setitem(shape, dtypes, data): ) +class AwkwardIndexable: + def __init__(self, value: int): + self._value = value + + def __int__(self): + raise TypeError("__int__() should not be called") + + def __index__(self): + return self._value + + +@pytest.mark.parametrize( + "x, key", + [ + (xp.asarray([0, 1]), AwkwardIndexable(1)), + (xp.asarray([[0, 1], [2, 3]]), (0, AwkwardIndexable(1))), + ] +) +def test_getitem_supports_index(x, key): + out = x[key] + assert out == xp.asarray(1) + + +@pytest.mark.parametrize( + "x, key, expected", + [ + (xp.asarray([0, 1]), AwkwardIndexable(1), xp.asarray([0, 42])), + (xp.asarray([[0, 1], [2, 3]]), (0, AwkwardIndexable(1)), xp.asarray([[0, 42], [2, 3]])), + ] +) +def test_setitem_supports_index(x, key, expected): + x[key] = xp.asarray(42) + ph.assert_array_elements("__setitem__", out=x, expected=expected, out_repr="x") + + @pytest.mark.unvectorized @pytest.mark.data_dependent_shapes @given(hh.shapes(), st.data())